diff --git a/docs/development/dynamic_lstm.md b/docs/development/dynamic_lstm.md
new file mode 100644
index 0000000000000000000000000000000000000000..f7d24a629c02263b42f19f3fb3004e3f4c5c2193
--- /dev/null
+++ b/docs/development/dynamic_lstm.md
@@ -0,0 +1,24 @@
+Dynamic LSTM
+==================
+
+
+The DynamicLSTM in MACE is implemented for Kaldi's time delay RNN models.
+
+The following pictures explain how to fuse components into a DynamicLSTMCell.
+
+Before fusing:
+
+
+

+
+
+
+After fusing:
+
+
+

+
+
+
+For more details about LSTMNonlinear in Kaldi,
+please refer to [LstmNonlinearComponent](http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164)
\ No newline at end of file
diff --git a/docs/development/imgs/DynamicLSTM.png b/docs/development/imgs/DynamicLSTM.png
new file mode 100644
index 0000000000000000000000000000000000000000..98351988b559aed3690252719181a57101bf866f
Binary files /dev/null and b/docs/development/imgs/DynamicLSTM.png differ
diff --git a/docs/development/imgs/FuseLSTM.png b/docs/development/imgs/FuseLSTM.png
new file mode 100644
index 0000000000000000000000000000000000000000..0cf89d62e71a3d9ff4bf0818556d18604372178c
Binary files /dev/null and b/docs/development/imgs/FuseLSTM.png differ
diff --git a/mace/benchmark/statistics.cc b/mace/benchmark/statistics.cc
index 6ea69be0a67dec96c553ee199c8d42b8fe535ee8..9af7fcb30bf343f82232a5d240b5b78536d3949c 100644
--- a/mace/benchmark/statistics.cc
+++ b/mace/benchmark/statistics.cc
@@ -131,6 +131,9 @@ int64_t StatMACs(const std::string &op_type,
output_shape.end(),
1,
std::multiplies());
+ } else if (op_type == "DynamicLSTM") {
+ macs = output_shape[0] * (filter_shape[0] * filter_shape[1]
+ + output_shape[1] * filter_shape[0] / 4);
}
return macs;
}
diff --git a/mace/ops/arm/fp32/gemv.cc b/mace/ops/arm/fp32/gemv.cc
index cd0f607fd63f16bb5c99ea0a369dc8423a6bf358..7caa0b5b23d1a9b30d81ce94126bfc2a1a5b82d6 100644
--- a/mace/ops/arm/fp32/gemv.cc
+++ b/mace/ops/arm/fp32/gemv.cc
@@ -48,8 +48,8 @@ MaceStatus Gemv::Compute(const OpContext *context,
Tensor *output) {
MACE_UNUSED(context);
- MACE_CHECK(output->size() == batch * lhs_height,
- "Need resize output tensor before call gemv.");
+ MACE_CHECK(output->size() >= batch * lhs_height,
+ "Output buffer is not large enough for computing gemv.");
Tensor::MappingGuard lhs_guard(lhs);
Tensor::MappingGuard rhs_guard(rhs);
diff --git a/mace/ops/common/lstm.cc b/mace/ops/common/lstm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..beea3f5b8081584b219cd6c662c4451dfe4cc223
--- /dev/null
+++ b/mace/ops/common/lstm.cc
@@ -0,0 +1,75 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Details are in
+// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164
+
+#include "mace/ops/common/lstm.h"
+#include "mace/utils/math.h"
+
+namespace mace {
+namespace ops {
+
+void LSTMNonlinearKernel(const float *input_data,
+ const float *prev_data,
+ const float *scale_data,
+ const float *params_data,
+ bool embed_scales,
+ index_t params_stride,
+ index_t cell_dim,
+ float *output_cell,
+ float *output_data) {
+ float i_scale = (embed_scales && scale_data) ? scale_data[0] : 1.0f;
+ float f_scale = (embed_scales && scale_data) ? scale_data[1] : 1.0f;
+ float o_scale = (embed_scales && scale_data) ? scale_data[2] : 1.0f;
+
+ if (prev_data == nullptr) {
+#pragma omp parallel for schedule(runtime)
+ for (int c = 0; c < cell_dim; ++c) {
+ float i_part = input_data[c];
+ float c_part = input_data[c + 2 * cell_dim];
+ float o_part = input_data[c + 3 * cell_dim];
+ float w_oc = params_data[c + params_stride * 2];
+ float i_t = ScalarSigmoid(i_part);
+ float c_t = i_t * i_scale * std::tanh(c_part);
+ float o_t = ScalarSigmoid(o_part + w_oc * c_t);
+ float m_t = o_t * o_scale * std::tanh(c_t);
+ output_cell[c] = c_t;
+ output_data[c] = m_t;
+ }
+ } else {
+#pragma omp parallel for schedule(runtime)
+ for (int c = 0; c < cell_dim; ++c) {
+ float i_part = input_data[c];
+ float f_part = input_data[c + cell_dim];
+ float c_part = input_data[c + 2 * cell_dim];
+ float o_part = input_data[c + 3 * cell_dim];
+ float c_prev = prev_data[c];
+ float w_ic = params_data[c];
+ float w_fc = params_data[c + params_stride];
+ float w_oc = params_data[c + params_stride * 2];
+ float i_t = ScalarSigmoid(i_part + w_ic * c_prev);
+ float f_t = ScalarSigmoid(f_part + w_fc * c_prev);
+ float c_t =
+ f_t * f_scale * c_prev + i_t * i_scale * std::tanh(c_part);
+ float o_t = ScalarSigmoid(o_part + w_oc * c_t);
+ float m_t = o_t * o_scale * std::tanh(c_t);
+ output_cell[c] = c_t;
+ output_data[c] = m_t;
+ }
+ }
+}
+
+} // namespace ops
+} // namespace mace
diff --git a/mace/ops/common/lstm.h b/mace/ops/common/lstm.h
new file mode 100644
index 0000000000000000000000000000000000000000..b835386041b6ba86f13818fe4f57c1efb1dff15d
--- /dev/null
+++ b/mace/ops/common/lstm.h
@@ -0,0 +1,37 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_OPS_COMMON_LSTM_H_
+#define MACE_OPS_COMMON_LSTM_H_
+
+#include "mace/core/types.h"
+namespace mace {
+namespace ops {
+
+void LSTMNonlinearKernel(const float *input_data,
+ const float *prev_data,
+ const float *scale_data,
+ const float *params_data,
+ bool embed_scales,
+ index_t params_stride,
+ index_t cell_dim,
+ float *output_cell,
+ float *output_data);
+
+
+} // namespace ops
+} // namespace mace
+
+#endif // MACE_OPS_COMMON_LSTM_H_
+
diff --git a/mace/ops/dynamic_lstm.cc b/mace/ops/dynamic_lstm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7fe93f21d6b7831bfe5fba3d21200a21923cdc2e
--- /dev/null
+++ b/mace/ops/dynamic_lstm.cc
@@ -0,0 +1,332 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This Op is for Fused-LstmNonlinearComponent
+// with prev cell states as inputs in Kaldi.
+// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164
+// More details are in docs/development/dynamic_lstm.md
+
+#include
+#include
+
+#include "mace/core/operator.h"
+#include "mace/ops/common/lstm.h"
+
+#ifdef MACE_ENABLE_NEON
+#include
+#include "mace/ops/arm/fp32/gemv.h"
+#else
+#include "mace/ops/ref/gemv.h"
+#endif // MACE_ENABLE_NEON
+
+namespace mace {
+namespace ops {
+
+template
+class DynamicLSTMOp;
+
+template
+class DynamicLSTMOp : public Operation {
+ public:
+ explicit DynamicLSTMOp(OpConstructContext *context)
+ : Operation(context),
+ prev_out_delay_(
+ Operation::GetOptionalArg("prev_out_delay", 0)),
+ prev_cell_delay_(
+ Operation::GetOptionalArg("prev_cell_delay", 0)),
+ prev_out_offset_(Operation::GetOptionalArg("prev_out_offset", 0)),
+ prev_out_dim_(Operation::GetOptionalArg("prev_out_dim", 0)),
+ prev_cell_dim_(Operation::GetOptionalArg("prev_cell_dim", 0)),
+ has_bias_a_(Operation::GetOptionalArg("bias_a", 1)),
+ has_bias_b_(Operation::GetOptionalArg("bias_b", 1)),
+ scale_(Operation::GetOptionalArg("scale", 1.0f)) {}
+
+ void UpdateCell(float *cell_data,
+ const index_t cell_dim,
+ const float scale) {
+ if (std::abs(scale - 1.f) < 1e-6)
+ return;
+ const index_t rounds = cell_dim / 4;
+#pragma omp parallel for schedule(runtime)
+ for (index_t i = 0; i < rounds * 4; i += 4) {
+#ifdef MACE_ENABLE_NEON
+ float32x4_t in_vec = vld1q_f32(cell_data + i);
+ float32x4_t scale_vec = vdupq_n_f32(scale);
+ in_vec = vmulq_f32(in_vec, scale_vec);
+ vst1q_f32(cell_data + i, in_vec);
+#else
+ for (int j = 0; j < 4; ++j) {
+ cell_data[i + j] *= scale;
+ }
+#endif
+ }
+ for (index_t i = rounds * 4; i < cell_dim; ++i) {
+ cell_data[i] *= scale;
+ }
+ }
+
+ void CopyAndUpdateCell(float *src_data,
+ const index_t cell_dim,
+ const float scale,
+ float *cell_data) {
+ if (std::abs(scale - 1.f) < 1e-6) {
+ memcpy(cell_data, src_data, cell_dim * sizeof(float));
+ return;
+ }
+
+ const index_t rounds = cell_dim / 4;
+#pragma omp parallel for schedule(runtime)
+ for (index_t i = 0; i < rounds * 4; i += 4) {
+#ifdef MACE_ENABLE_NEON
+ float32x4_t in_vec = vld1q_f32(src_data + i);
+ float32x4_t scale_vec = vdupq_n_f32(scale);
+ in_vec = vmulq_f32(in_vec, scale_vec);
+ vst1q_f32(cell_data + i, in_vec);
+#else
+ for (int j = 0; j < 4; ++j) {
+ cell_data[i + j] = src_data[i + j] * scale;
+ }
+#endif
+ }
+ for (index_t i = rounds * 4; i < cell_dim; ++i) {
+ cell_data[i] = src_data[i] * scale;
+ }
+ }
+
+ MaceStatus Run(OpContext *context) override {
+ MACE_UNUSED(context);
+ int max_input_num = 4;
+ MACE_CHECK(this->InputSize() >= max_input_num,
+ "DynamicLSTM has at least four inputs.");
+ MACE_CHECK(prev_cell_delay_ < 0 && prev_out_delay_ < 0);
+ MACE_CHECK(prev_out_dim_ > 0 && prev_cell_dim_ > 0);
+ const Tensor *input = this->Input(INPUT);
+ const Tensor *weights_a = this->Input(WEIGHTS_A);
+ const Tensor *lstm_params = this->Input(PARAMS);
+ const Tensor *weights_b = this->Input(WEIGHTS_B);
+ if (has_bias_a_) {
+ max_input_num++;
+ MACE_CHECK(this->InputSize() >= max_input_num,
+ "The first affine needs a bias input.");
+ }
+ const Tensor *bias_a = has_bias_a_ ?
+ this->Input(max_input_num - 1) :
+ nullptr;
+ if (has_bias_b_) {
+ max_input_num++;
+ MACE_CHECK(this->InputSize() >= max_input_num,
+ "The second affine needs a bias input.");
+ }
+ const Tensor *bias_b = has_bias_b_ ?
+ this->Input(max_input_num - 1) :
+ nullptr;
+
+ const index_t input_rank = input->dim_size();
+ MACE_CHECK(input_rank >= 2,
+ "Dynamic LSTM Cell's input dim size should be >= 2.");
+ const std::vector &input_shape = input->shape();
+ const index_t batch =
+ std::accumulate(input_shape.begin(), input_shape.end() - 2, 1,
+ std::multiplies());
+ const index_t chunk = input_shape[input_rank - 2];
+ const index_t input_dim = input_shape[input_rank - 1];
+
+ const index_t affine_a_in_dim = input_dim + prev_out_dim_;
+ const index_t affine_a_out_dim = weights_a->dim(0);
+ const index_t affine_a_depth = weights_a->dim(1);
+ MACE_CHECK(affine_a_in_dim == affine_a_depth)
+ << "affine_a's input_dim:" << affine_a_in_dim
+ << "!=" << "affine_a's weights' depth:" << affine_a_depth << std::endl;
+
+ const index_t lstm_input_dim = affine_a_out_dim + prev_cell_dim_;
+ const index_t lstm_cell_dim = lstm_input_dim / 5;
+ const index_t params_stride = lstm_params->dim(1);
+ MACE_CHECK(lstm_input_dim == (lstm_cell_dim * 5));
+ MACE_CHECK(lstm_params->dim(0) == 3 &&
+ params_stride == lstm_cell_dim && lstm_cell_dim == prev_cell_dim_)
+ << "lstm params rows:" << lstm_params->dim(0)
+ << "params_stride:"<< params_stride
+ << "!=" << "cell_dim:"<< lstm_cell_dim << std::endl;
+ const index_t affine_b_out_dim = weights_b->dim(0);
+ const index_t affine_b_depth = weights_b->dim(1);
+ const index_t affine_b_in_dim = lstm_cell_dim;
+ MACE_CHECK(affine_b_in_dim == affine_b_depth)
+ << "affine_b's input_dim:" << affine_b_in_dim
+ << "!=" << "affine_b's weights' depth:" << affine_b_depth << std::endl;
+
+ const index_t output_dim = affine_b_out_dim;
+ MACE_CHECK(prev_out_offset_ + prev_out_dim_ <= output_dim);
+
+ const index_t affine_a_in_size =
+ PadAlignSize(affine_a_in_dim * sizeof(float));
+ const index_t affine_a_out_size =
+ PadAlignSize(affine_a_out_dim * sizeof(float));
+ const index_t affine_b_in_size =
+ PadAlignSize(affine_b_in_dim * sizeof(float));
+ const index_t affine_b_out_size =
+ PadAlignSize(affine_b_out_dim * sizeof(float));
+
+ const int out_buf_chunk = abs(prev_out_delay_);
+ const int cell_buf_chunk = abs(prev_cell_delay_);
+ const index_t out_buf_size =
+ PadAlignSize(out_buf_chunk * prev_out_dim_ * sizeof(float));
+ const index_t cell_buf_size =
+ PadAlignSize(cell_buf_chunk * prev_cell_dim_ * sizeof(float));
+ ScratchBuffer *scratch = context->device()->scratch_buffer();
+ scratch->Rewind();
+ scratch->GrowSize(affine_a_in_size + affine_a_out_size
+ + affine_b_in_size + affine_b_out_size
+ + out_buf_size + cell_buf_size);
+
+ Tensor prev_out(scratch->Scratch(out_buf_size), DT_FLOAT);
+ prev_out.Reshape({out_buf_chunk, prev_out_dim_});
+ float *prev_out_data = prev_out.mutable_data();
+
+ Tensor prev_cell(scratch->Scratch(cell_buf_size), DT_FLOAT);
+ prev_cell.Reshape({cell_buf_chunk, prev_cell_dim_});
+ float *prev_cell_data = prev_cell.mutable_data();
+
+ Tensor affine_a_in(scratch->Scratch(affine_a_in_size), DT_FLOAT);
+ affine_a_in.Reshape({1, affine_a_in_dim});
+ float *affine_a_in_data = affine_a_in.mutable_data();
+
+ Tensor affine_a_out(scratch->Scratch(affine_a_out_size), DT_FLOAT);
+ affine_a_out.Reshape({1, affine_a_out_dim});
+ float *affine_a_out_data = affine_a_out.mutable_data();
+
+ Tensor affine_b_in(scratch->Scratch(affine_b_in_size), DT_FLOAT);
+ affine_b_in.Reshape({1, affine_b_in_dim});
+ float *affine_b_in_data = affine_b_in.mutable_data();
+
+ Tensor affine_b_out(scratch->Scratch(affine_b_out_size), DT_FLOAT);
+ affine_b_out.Reshape({1, affine_b_out_dim});
+ float *affine_b_out_data = affine_b_out.mutable_data();
+
+ Tensor *output = this->Output(OUTPUT);
+
+ std::vector output_shape = input->shape();
+ output_shape[1] = output_dim;
+
+ MACE_RETURN_IF_ERROR(output->Resize(output_shape));
+
+ Tensor::MappingGuard input_guard(input);
+ Tensor::MappingGuard lstm_params_guard(lstm_params);
+ Tensor::MappingGuard output_guard(output);
+ const float *input_data = input->data();
+ const float *lstm_params_data = lstm_params->data();
+ float *output_data = output->mutable_data();
+
+ for (int b = 0; b < batch; ++b) {
+ int prev_out_idx = prev_out_delay_;
+ int prev_cell_idx = prev_cell_delay_;
+ prev_cell.Clear();
+ prev_out.Clear();
+ affine_a_in.Clear();
+ affine_a_out.Clear();
+ affine_b_in.Clear();
+ affine_b_out.Clear();
+ for (int i = 0; i < chunk; ++i) {
+ // Append
+ memcpy(affine_a_in_data, input_data, input_dim * sizeof(float));
+ if (prev_out_idx >= 0) {
+ memcpy(affine_a_in_data + input_dim,
+ prev_out_data + prev_out_idx % out_buf_chunk * prev_out_dim_,
+ prev_out_dim_ * sizeof(float));
+ }
+ // Affine
+ gemv_.Compute(context,
+ weights_a,
+ &affine_a_in,
+ bias_a,
+ 1,
+ affine_a_out_dim,
+ affine_a_depth,
+ false,
+ false,
+ &affine_a_out);
+ // Prepare LSTMNonlinear input and output pointer
+ float *prev_cell_ptr =
+ prev_cell_idx < 0 ? nullptr :
+ prev_cell_data + prev_cell_idx % cell_buf_chunk * prev_cell_dim_;
+ float *curr_cell_ptr =
+ prev_cell_data + i % cell_buf_chunk * prev_cell_dim_;
+ // LSTMNonlinear
+ LSTMNonlinearKernel(affine_a_out_data,
+ prev_cell_ptr,
+ nullptr,
+ lstm_params_data,
+ false,
+ params_stride,
+ lstm_cell_dim,
+ curr_cell_ptr,
+ affine_b_in_data);
+ UpdateCell(curr_cell_ptr, prev_cell_dim_, scale_);
+ // Affine
+ gemv_.Compute(context,
+ weights_b,
+ &affine_b_in,
+ bias_b,
+ 1,
+ affine_b_out_dim,
+ affine_b_depth,
+ false,
+ false,
+ &affine_b_out);
+ // Output
+ memcpy(output_data,
+ affine_b_out_data,
+ output_dim * sizeof(float));
+ // Update
+ float *curr_out_ptr = prev_out_data + i % out_buf_chunk * prev_out_dim_;
+ CopyAndUpdateCell(affine_b_out_data + prev_out_offset_,
+ prev_out_dim_,
+ scale_,
+ curr_out_ptr);
+ input_data += input_dim;
+ output_data += output_dim;
+ prev_out_idx++;
+ prev_cell_idx++;
+ }
+ }
+
+ return MaceStatus::MACE_SUCCESS;
+ }
+
+ private:
+ int prev_out_delay_;
+ int prev_cell_delay_;
+ int prev_out_offset_;
+ int prev_out_dim_;
+ int prev_cell_dim_;
+ int has_bias_a_;
+ int has_bias_b_;
+ float scale_;
+
+#ifdef MACE_ENABLE_NEON
+ arm::fp32::Gemv gemv_;
+#else
+ ref::Gemv gemv_;
+#endif // MACE_ENABLE_NEON
+
+ MACE_OP_INPUT_TAGS(INPUT, WEIGHTS_A, PARAMS, WEIGHTS_B);
+ MACE_OP_OUTPUT_TAGS(OUTPUT);
+};
+
+void RegisterDynamicLSTM(OpRegistryBase *op_registry) {
+ MACE_REGISTER_OP(op_registry, "DynamicLSTM", DynamicLSTMOp,
+ DeviceType::CPU, float);
+}
+
+} // namespace ops
+} // namespace mace
diff --git a/mace/ops/dynamic_lstm_benchmark.cc b/mace/ops/dynamic_lstm_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..00cd11bfd28fe4108847021dd11a40a237d15125
--- /dev/null
+++ b/mace/ops/dynamic_lstm_benchmark.cc
@@ -0,0 +1,124 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/benchmark/statistics.h"
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/lstmcell_test_util.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+namespace {
+template
+void DynamicLSTM(int iters,
+ int chunk,
+ int input_dim,
+ int output_dim,
+ int cell_dim,
+ int prev_out_dim,
+ int delay) {
+ mace::testing::StopTiming();
+
+ OpsTestNet net;
+ MACE_CHECK(prev_out_dim <= output_dim);
+ const int weights_a_rows = 4 * cell_dim;
+ const int weights_a_cols = input_dim + prev_out_dim;
+ const int bias_a_rows = weights_a_rows;
+
+ const int weights_b_rows = output_dim;
+ const int weights_b_cols = cell_dim;
+ const int bias_b_rows = weights_b_rows;
+
+ // Add input data
+ net.AddRandomInput("Input", {chunk, input_dim});
+ net.AddRandomInput("Weight_A",
+ {weights_a_rows, weights_a_cols},
+ true);
+ net.AddRandomInput("Params",
+ {3, cell_dim},
+ true);
+ net.AddRandomInput("Weight_B",
+ {weights_b_rows, weights_b_cols},
+ true);
+ net.AddRandomInput("Bias_A", {bias_a_rows}, true);
+ net.AddRandomInput("Bias_B", {bias_b_rows}, true);
+
+ if (D == DeviceType::CPU) {
+ OpDefBuilder("DynamicLSTM", "DynamicLSTMTest")
+ .Input("Input")
+ .Input("Weight_A")
+ .Input("Params")
+ .Input("Weight_B")
+ .Input("Bias_A")
+ .Input("Bias_B")
+ .Output("Output")
+ .AddIntArg("prev_out_delay", -delay)
+ .AddIntArg("prev_cell_delay", -delay)
+ .AddIntArg("prev_out_dim", prev_out_dim)
+ .AddIntArg("prev_cell_dim", cell_dim)
+ .AddIntArg("T", static_cast(DataTypeToEnum::value))
+ .Finalize(net.NewOperatorDef());
+ } else {
+ MACE_NOT_IMPLEMENTED;
+ }
+
+ // Warm-up
+ for (int i = 0; i < 5; ++i) {
+ net.RunOp(D);
+ }
+ net.Sync();
+
+ mace::testing::StartTiming();
+ while (iters--) {
+ net.RunOp(D);
+ }
+ net.Sync();
+}
+} // namespace
+
+#define MACE_BM_DYNAMIC_LSTM_MACRO( \
+ N, ID, OD, CD, POD, DELAY, TYPE, DEVICE) \
+ static void \
+ MACE_BM_DYNAMIC_LSTM_##N##_##ID##_##OD##_##CD##_##POD##_##DELAY##_##TYPE\
+ ##_##DEVICE( \
+ int iters) { \
+ int64_t wa_size = 4 * CD * (ID + POD); \
+ int64_t wb_size = OD * CD; \
+ int64_t prev_size = DELAY * (POD + CD); \
+ int64_t in_out_size = N * (ID + OD); \
+ int64_t bias_size = 4 * CD + OD; \
+ const int64_t macs = static_cast(iters) * \
+ mace::benchmark::StatMACs("DynamicLSTM", {4 * CD, ID + POD}, {N, OD});\
+ const int64_t tot = static_cast(iters) * (in_out_size + prev_size\
+ + wa_size + wb_size + bias_size); \
+ mace::testing::MacsProcessed(macs); \
+ mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \
+ DynamicLSTM(iters, N, ID, OD, CD, POD, DELAY); \
+ } \
+ MACE_BENCHMARK( \
+ MACE_BM_DYNAMIC_LSTM_##N##_##ID##_##OD##_##CD##_##POD##_##DELAY \
+ ##_##TYPE##_##DEVICE)
+
+#define MACE_BM_DYNAMIC_LSTM(N, ID, OD, CD, POD, DELAY) \
+ MACE_BM_DYNAMIC_LSTM_MACRO(N, ID, OD, CD, POD, DELAY, float, CPU);
+
+MACE_BM_DYNAMIC_LSTM(50, 184, 128, 184, 64, 3);
+MACE_BM_DYNAMIC_LSTM(50, 64, 256, 64, 128, 3);
+MACE_BM_DYNAMIC_LSTM(80, 64, 256, 128, 64, 3);
+
+} // namespace test
+} // namespace ops
+} // namespace mace
diff --git a/mace/ops/lstm_nonlinear.cc b/mace/ops/lstm_nonlinear.cc
new file mode 100644
index 0000000000000000000000000000000000000000..745c4d79674c6e2becc2eb49b2d855a2819a0e15
--- /dev/null
+++ b/mace/ops/lstm_nonlinear.cc
@@ -0,0 +1,113 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This Op is for LstmNonlinearComponent in Kaldi.
+// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164
+
+#include
+#include
+
+#include "mace/core/operator.h"
+#include "mace/ops/common/lstm.h"
+
+namespace mace {
+namespace ops {
+
+template
+class LSTMNonlinearOp;
+
+template
+class LSTMNonlinearOp : public Operation {
+ public:
+ explicit LSTMNonlinearOp(OpConstructContext *context)
+ : Operation(context) {}
+
+ MaceStatus Run(OpContext *context) override {
+ MACE_UNUSED(context);
+ const Tensor *input = this->Input(INPUT);
+ MACE_CHECK(this->InputSize() >= 2,
+ "LSTMNonlinear should have at least 2 inputs.");
+ const Tensor *params = this->Input(PARAMS);
+ Tensor *output = this->Output(OUTPUT);
+
+ MACE_CHECK(input->dim_size() >= 2)
+ << "The input dim size should >= 2";
+ MACE_CHECK(params->dim_size() == 2)
+ << "The params dim size should be 2";
+ return Compute(input, params, output);
+ }
+
+ MaceStatus Compute(const Tensor *input,
+ const Tensor *params,
+ Tensor *output) {
+ const std::vector &input_shape = input->shape();
+ const std::vector ¶ms_shape = params->shape();
+
+ const index_t num_rows =
+ std::accumulate(input_shape.begin(), input_shape.end() - 1, 1,
+ std::multiplies());
+ index_t rank = input->dim_size();
+ const index_t input_cols = input_shape[rank - 1];
+ const index_t cell_dim = input_cols / 5;
+ bool embed_scales = input_cols == cell_dim * 5 + 3;
+ const index_t params_stride = params_shape[1];
+
+ MACE_CHECK(input_cols == (cell_dim * 5) || embed_scales);
+ MACE_CHECK(params_shape[0] == 3 && params_shape[1] == cell_dim);
+
+ const index_t output_dim = cell_dim * 2;
+ std::vector output_shape = input->shape();
+ output_shape[rank - 1] = output_dim;
+ MACE_RETURN_IF_ERROR(output->Resize(output_shape));
+
+ Tensor::MappingGuard input_guard(input);
+ Tensor::MappingGuard params_guard(params);
+ Tensor::MappingGuard output_guard(output);
+ const float *input_data = input->data();
+ const float *params_data = params->data();
+ float *output_data = output->mutable_data();
+#pragma omp parallel for schedule(runtime)
+ for (int r = 0; r < num_rows; ++r) {
+ const float *input_row = input_data + r * input_cols;
+ const float *prev_row = input_row + 4 * cell_dim;
+ const float *scale_data =
+ embed_scales ? prev_row + cell_dim : nullptr;
+ float *output_cell = output_data + r * output_dim;
+ float *output_row = output_cell + cell_dim;
+ LSTMNonlinearKernel(input_row,
+ prev_row,
+ scale_data,
+ params_data,
+ embed_scales,
+ params_stride,
+ cell_dim,
+ output_cell,
+ output_row);
+ }
+
+ return MaceStatus::MACE_SUCCESS;
+ }
+
+ protected:
+ MACE_OP_INPUT_TAGS(INPUT, PARAMS);
+ MACE_OP_OUTPUT_TAGS(OUTPUT);
+};
+
+void RegisterLSTMNonlinear(OpRegistryBase *op_registry) {
+ MACE_REGISTER_OP(op_registry, "LSTMNonlinearOp", LSTMNonlinearOp,
+ DeviceType::CPU, float);
+}
+
+} // namespace ops
+} // namespace mace
diff --git a/mace/ops/lstm_nonlinear_benchmark.cc b/mace/ops/lstm_nonlinear_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6ca881c57f4fa768368a6977fbe89ac276dd83c1
--- /dev/null
+++ b/mace/ops/lstm_nonlinear_benchmark.cc
@@ -0,0 +1,85 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/lstmcell_test_util.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+namespace {
+template
+void LSTMNonlinear(int iters,
+ int batch,
+ int input_dim) {
+ mace::testing::StopTiming();
+
+ OpsTestNet net;
+
+ int cell_dim = input_dim / 5;
+
+ // Add input data
+ net.AddRandomInput("Input", {batch, input_dim});
+ net.AddRandomInput("Params",
+ {3, cell_dim},
+ true);
+ if (D == DeviceType::CPU) {
+ OpDefBuilder("LSTMNonlinear", "LSTMNonlinearTest")
+ .Input("Input")
+ .Input("Params")
+ .Output("Output")
+ .AddIntArg("T", static_cast(DataTypeToEnum::value))
+ .Finalize(net.NewOperatorDef());
+ } else {
+ MACE_NOT_IMPLEMENTED;
+ }
+
+ // Warm-up
+ for (int i = 0; i < 5; ++i) {
+ net.RunOp(D);
+ }
+ net.Sync();
+
+ mace::testing::StartTiming();
+ while (iters--) {
+ net.RunOp(D);
+ }
+ net.Sync();
+}
+} // namespace
+
+#define MACE_BM_LSTM_NONLIN_MACRO(N, IN_DIM, TYPE, DEVICE) \
+ static void \
+ MACE_BM_LSTM_NONLIN_##N##_##IN_DIM##_##TYPE##_##DEVICE(\
+ int iters) { \
+ const int64_t tot = \
+ static_cast(iters) * (N * IN_DIM + 3 * (IN_DIM / 5));\
+ mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \
+ LSTMNonlinear(iters, N, IN_DIM); \
+ } \
+ MACE_BENCHMARK( \
+ MACE_BM_LSTM_NONLIN_##N##_##IN_DIM##_##TYPE##_##DEVICE)
+
+#define MACE_BM_LSTM_NONLIN(N, IN_DIM) \
+ MACE_BM_LSTM_NONLIN_MACRO(N, IN_DIM, float, CPU);
+
+MACE_BM_LSTM_NONLIN(50, 200);
+MACE_BM_LSTM_NONLIN(50, 920);
+MACE_BM_LSTM_NONLIN(80, 640);
+
+} // namespace test
+} // namespace ops
+} // namespace mace
diff --git a/mace/ops/opencl/buffer/softmax.h b/mace/ops/opencl/buffer/softmax.h
index db3f2800cefc276fda229cb7d27950f79c19fcea..3ab6a7cef1bd1d760ea70e1409f687d664f51996 100644
--- a/mace/ops/opencl/buffer/softmax.h
+++ b/mace/ops/opencl/buffer/softmax.h
@@ -32,12 +32,16 @@ namespace buffer {
template
class SoftmaxKernel : public OpenCLSoftmaxKernel {
public:
+ explicit SoftmaxKernel(bool use_log)
+ : use_log_(use_log) {}
+
MaceStatus Compute(
OpContext *context,
const Tensor *logits,
Tensor *output) override;
private:
+ bool use_log_;
cl::Kernel kernel_;
uint32_t kwg_size_;
std::vector input_shape_;
@@ -88,6 +92,7 @@ MaceStatus SoftmaxKernel::Compute(
built_options.emplace("-DIN_DATA_TYPE=" + DtToCLDt(logits->dtype()));
built_options.emplace("-DOUT_DATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt));
+ if (use_log_) built_options.emplace("-DUSE_LOG");
MACE_RETURN_IF_ERROR(runtime->BuildKernel("softmax_buffer", kernel_name,
built_options, &kernel_));
diff --git a/mace/ops/opencl/cl/softmax.cl b/mace/ops/opencl/cl/softmax.cl
index 39f8c89fe6e7b94bcfb803e1e6da1c562c86f320..c4babfd23be9ae7a23a12ee74c83c1b16d5dc8a5 100644
--- a/mace/ops/opencl/cl/softmax.cl
+++ b/mace/ops/opencl/cl/softmax.cl
@@ -73,13 +73,25 @@ __kernel void softmax(OUT_OF_RANGE_PARAMS
switch(exceeded) {
case 1:
data.z = native_exp(data.z) / sum;
+#ifdef USE_LOG
+ data.z = native_log(data.z);
+#endif
case 2:
data.y = native_exp(data.y) / sum;
+#ifdef USE_LOG
+ data.y = native_log(data.y);
+#endif
case 3:
data.x = native_exp(data.x) / sum;
+#ifdef USE_LOG
+ data.x = native_log(data.x);
+#endif
break;
default:
data = native_exp(data) / sum;
+#ifdef USE_LOG
+ data = native_log(data);
+#endif
}
WRITE_IMAGET(output, (int2)(pos, hb_idx), data);
diff --git a/mace/ops/opencl/cl/softmax_buffer.cl b/mace/ops/opencl/cl/softmax_buffer.cl
index 2a96a237d91c9d05fe7516527a331c681f62a3be..78fae30ff2e1173fb75c5eca5effee095d35e8b8 100644
--- a/mace/ops/opencl/cl/softmax_buffer.cl
+++ b/mace/ops/opencl/cl/softmax_buffer.cl
@@ -75,14 +75,26 @@ __kernel void softmax(BUFFER_OUT_OF_RANGE_PARAMS
switch(remain_chan) {
case 3:
output[offset + 2] = native_exp(CONVERT(input[offset + 2]) - max_value) / sum;
+#ifdef USE_LOG
+ output[offset + 2] = native_log(output[offset + 2]);
+#endif
case 2:
output[offset + 1] = native_exp(CONVERT(input[offset + 1]) - max_value) / sum;
+#ifdef USE_LOG
+ output[offset + 1] = native_log(output[offset + 1]);
+#endif
case 1:
output[offset] = native_exp(CONVERT(input[offset]) - max_value) / sum;
+#ifdef USE_LOG
+ output[offset] = native_log(output[offset]);
+#endif
}
} else {
data = CONVERT4(vload4(0, input + offset));
data = native_exp(data - max_value) / sum;
+#ifdef USE_LOG
+ data = native_log(data)
+#endif
VSTORE4(CONVERT_TO(data, OUT_DATA_TYPE4), output, offset);
}
}
diff --git a/mace/ops/opencl/image/softmax.h b/mace/ops/opencl/image/softmax.h
index 2bbf1aa31d84d4d011bd5ab1875779f51ea236f3..3aa84bb5091066bff8565d3428fca7ebe4badafd 100644
--- a/mace/ops/opencl/image/softmax.h
+++ b/mace/ops/opencl/image/softmax.h
@@ -59,12 +59,16 @@ inline std::vector LocalWS(OpenCLRuntime *runtime,
template
class SoftmaxKernel : public OpenCLSoftmaxKernel {
public:
+ explicit SoftmaxKernel(bool use_log)
+ : use_log_(use_log) {}
+
MaceStatus Compute(
OpContext *context,
const Tensor *logits,
Tensor *output) override;
private:
+ bool use_log_;
cl::Kernel kernel_;
uint32_t kwg_size_;
std::vector input_shape_;
@@ -114,6 +118,8 @@ MaceStatus SoftmaxKernel::Compute(
auto dt = DataTypeToEnum::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpCompatibleCLCMDDt(dt));
+ if (use_log_)
+ built_options.emplace("-DUSE_LOG");
MACE_RETURN_IF_ERROR(runtime->BuildKernel("softmax", kernel_name,
built_options, &kernel_));
diff --git a/mace/ops/ops_registry.cc b/mace/ops/ops_registry.cc
index 7fc3545883ee855a578c001ac3ff75ff574261b6..26bf046391f6e3ad156a3366c75ddbd1e515c9f2 100644
--- a/mace/ops/ops_registry.cc
+++ b/mace/ops/ops_registry.cc
@@ -34,6 +34,7 @@ extern void RegisterDeconv2D(OpRegistryBase *op_registry);
extern void RegisterDepthToSpace(OpRegistryBase *op_registry);
extern void RegisterDepthwiseConv2d(OpRegistryBase *op_registry);
extern void RegisterDepthwiseDeconv2d(OpRegistryBase *op_registry);
+extern void RegisterDynamicLSTM(OpRegistryBase *op_registry);
extern void RegisterEltwise(OpRegistryBase *op_registry);
extern void RegisterExpandDims(OpRegistryBase *op_registry);
extern void RegisterFill(OpRegistryBase *op_registry);
@@ -42,9 +43,11 @@ extern void RegisterGather(OpRegistryBase *op_registry);
extern void RegisterIdentity(OpRegistryBase *op_registry);
extern void RegisterInferConv2dShape(OpRegistryBase *op_registry);
extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry);
+extern void RegisterLSTMNonlinear(OpRegistryBase *op_registry);
extern void RegisterMatMul(OpRegistryBase *op_registry);
extern void RegisterOneHot(OpRegistryBase *op_registry);
extern void RegisterPad(OpRegistryBase *op_registry);
+extern void RegisterPadContext(OpRegistryBase *op_registry);
extern void RegisterPNorm(OpRegistryBase *op_registry);
extern void RegisterPooling(OpRegistryBase *op_registry);
extern void RegisterReduce(OpRegistryBase *op_registry);
@@ -68,7 +71,6 @@ extern void RegisterStack(OpRegistryBase *op_registry);
extern void RegisterStridedSlice(OpRegistryBase *op_registry);
extern void RegisterSumGroup(OpRegistryBase *op_registry);
extern void RegisterTargetRMSNorm(OpRegistryBase *op_registry);
-extern void RegisterTimeOffset(OpRegistryBase *op_registry);
extern void RegisterTranspose(OpRegistryBase *op_registry);
extern void RegisterUnstack(OpRegistryBase *op_registry);
@@ -102,6 +104,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterDepthToSpace(this);
ops::RegisterDepthwiseConv2d(this);
ops::RegisterDepthwiseDeconv2d(this);
+ ops::RegisterDynamicLSTM(this);
ops::RegisterEltwise(this);
ops::RegisterExpandDims(this);
ops::RegisterFill(this);
@@ -110,9 +113,11 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterIdentity(this);
ops::RegisterInferConv2dShape(this);
ops::RegisterLocalResponseNorm(this);
+ ops::RegisterLSTMNonlinear(this);
ops::RegisterMatMul(this);
ops::RegisterOneHot(this);
ops::RegisterPad(this);
+ ops::RegisterPadContext(this);
ops::RegisterPNorm(this);
ops::RegisterPooling(this);
ops::RegisterReduce(this);
@@ -136,7 +141,6 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterSqueeze(this);
ops::RegisterSumGroup(this);
ops::RegisterTargetRMSNorm(this);
- ops::RegisterTimeOffset(this);
ops::RegisterTranspose(this);
ops::RegisterUnstack(this);
diff --git a/mace/ops/time_offset.cc b/mace/ops/pad_context.cc
similarity index 53%
rename from mace/ops/time_offset.cc
rename to mace/ops/pad_context.cc
index d9343fc327438a965fe4b3e98a583783a6d4993a..6c463ec9830b2e22e234cef6e4ec7eddc61d9906 100644
--- a/mace/ops/time_offset.cc
+++ b/mace/ops/pad_context.cc
@@ -25,14 +25,15 @@ namespace mace {
namespace ops {
template
-class TimeOffsetOp;
+class PadContextOp;
template
-class TimeOffsetOp : public Operation {
+class PadContextOp : public Operation {
public:
- explicit TimeOffsetOp(OpConstructContext *context)
+ explicit PadContextOp(OpConstructContext *context)
: Operation(context),
- offset_(Operation::GetOptionalArg("offset", 0)) {}
+ left_context_(Operation::GetOptionalArg("left_context", 0)),
+ right_context_(Operation::GetOptionalArg("right_context", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
@@ -41,27 +42,38 @@ class TimeOffsetOp : public Operation {
index_t rank = input->dim_size();
MACE_CHECK(rank >= 2, "input's rank should >= 2.");
+ MACE_CHECK(left_context_ > 0 && right_context_ > 0,
+ "left context and right context should be greater than zero");
const std::vector &input_shape = input->shape();
const index_t batch =
std::accumulate(input_shape.begin(), input_shape.end() - 2, 1,
std::multiplies());
- const index_t frames = input_shape[rank - 2];
- const index_t input_dim = input_shape[rank - 1];
- MACE_RETURN_IF_ERROR(output->ResizeLike(input));
+ const index_t chunk = input_shape[rank - 2];
+ const index_t dim = input_shape[rank - 1];
+ const index_t output_chunk = chunk + left_context_ + right_context_;
+ std::vector output_shape = input->shape();
+ output_shape[rank - 2] = output_chunk;
+ MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data();
T *output_data = output->mutable_data();
-#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t i = 0; i < batch; ++i) {
- for (index_t j = 0; j < frames; ++j) {
- index_t time_index = offset_ + j;
- index_t index = Clamp(time_index, 0, frames - 1);
- T *output_base = output_data + (i * frames + j) * input_dim;
- const T *input_base = input_data + (i * frames + index) * input_dim;
- memcpy(output_base, input_base, input_dim * sizeof(T));
+ T *out_base = output_data + i * output_chunk * dim;
+ const T *in_base = input_data + i * chunk * dim;
+#pragma omp parallel for schedule(runtime)
+ for (index_t j = 0; j < left_context_; ++j) {
+ memcpy(out_base + j * dim, in_base, dim * sizeof(T));
+ }
+ out_base = out_base + left_context_ * dim;
+ memcpy(out_base, in_base, chunk * dim * sizeof(T));
+ out_base = out_base + chunk * dim;
+ in_base = in_base + (chunk -1) * dim;
+#pragma omp parallel for schedule(runtime)
+ for (index_t j = 0; j < right_context_; ++j) {
+ memcpy(out_base + j * dim, in_base, dim * sizeof(T));
}
}
@@ -69,11 +81,12 @@ class TimeOffsetOp : public Operation {
}
private:
- int offset_;
+ int left_context_;
+ int right_context_;
};
-void RegisterTimeOffset(OpRegistryBase *op_registry) {
- MACE_REGISTER_OP(op_registry, "TimeOffset", TimeOffsetOp,
+void RegisterPadContext(OpRegistryBase *op_registry) {
+ MACE_REGISTER_OP(op_registry, "PadContext", PadContextOp,
DeviceType::CPU, float);
}
diff --git a/mace/ops/pad_context_benchmark.cc b/mace/ops/pad_context_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3384f79d9f7aef9d18785cf89511de64f5ca609d
--- /dev/null
+++ b/mace/ops/pad_context_benchmark.cc
@@ -0,0 +1,79 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+namespace {
+template
+void PadContextBM(int iters,
+ const std::vector &input_shape,
+ const int left_context,
+ const int right_context) {
+ mace::testing::StopTiming();
+
+ // Construct graph
+ OpsTestNet net;
+
+ net.AddRandomInput("Input", input_shape);
+
+ OpDefBuilder("PadContext", "PadContextBM")
+ .Input("Input")
+ .Output("Output")
+ .AddIntArg("left_context", left_context)
+ .AddIntArg("right_context", right_context)
+ .AddIntArg("T", static_cast(DataTypeToEnum::value))
+ .Finalize(net.NewOperatorDef());
+
+ // Warm-up
+ for (int i = 0; i < 5; ++i) {
+ net.RunOp(D);
+ net.Sync();
+ }
+
+ mace::testing::StartTiming();
+ while (iters--) {
+ net.RunOp(D);
+ net.Sync();
+ }
+}
+} // namespace
+
+#define MACE_BM_PAD_CONTEXT_MACRO(N, H, W, L, R, TYPE, DEVICE) \
+ static void \
+ MACE_BM_PAD_CONTEXT_##N##_##H##_##W##_##L##_##R##_##TYPE##_##DEVICE( \
+ int iters) { \
+ const int64_t tot = static_cast(iters) * N * H * W; \
+ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
+ PadContextBM(iters, {N, H, W}, L, R); \
+ } \
+ MACE_BENCHMARK( \
+ MACE_BM_PAD_CONTEXT_##N##_##H##_##W##_##L##_##R##_##TYPE##_##DEVICE)
+
+#define MACE_BM_PAD_CONTEXT(N, H, W, L, R) \
+ MACE_BM_PAD_CONTEXT_MACRO(N, H, W, L, R, float, CPU);
+
+MACE_BM_PAD_CONTEXT(1, 32, 32, 5, 5);
+MACE_BM_PAD_CONTEXT(2, 32, 32, 7, 7);
+MACE_BM_PAD_CONTEXT(1, 32, 32, 3, 3);
+MACE_BM_PAD_CONTEXT(1, 128, 128, 9, 9);
+MACE_BM_PAD_CONTEXT(3, 128, 128, 7, 7);
+
+} // namespace test
+} // namespace ops
+} // namespace mace
diff --git a/mace/ops/pad_context_test.cc b/mace/ops/pad_context_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..72f7ba576c3e8b655a3da61302147ab91aeb0fdf
--- /dev/null
+++ b/mace/ops/pad_context_test.cc
@@ -0,0 +1,87 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class PadContextOpTest : public OpsTestBase {};
+
+namespace {
+template
+void TestPadContext(const std::vector &input_shape,
+ const std::vector &input,
+ const int left_context,
+ const int right_context,
+ const std::vector &output_shape,
+ const std::vector &output) {
+ OpsTestNet net;
+ net.AddInputFromArray(MakeString("Input"),
+ input_shape,
+ input);
+
+ OpDefBuilder("PadContext", "PadContextTest")
+ .Input("Input")
+ .Output("Output")
+ .AddIntArg("left_context", left_context)
+ .AddIntArg("right_context", right_context)
+ .Finalize(net.NewOperatorDef());
+
+ net.RunOp();
+
+ auto expected = net.CreateTensor(output_shape, output);
+ ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5);
+}
+} // namespace
+
+TEST_F(PadContextOpTest, Simple2Dim) {
+ TestPadContext(
+ {3, 5},
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
+ 2, 3, {8, 5},
+ {1, 2, 3, 4, 5,
+ 1, 2, 3, 4, 5,
+ 1, 2, 3, 4, 5,
+ 6, 7, 8, 9, 10,
+ 11, 12, 13, 14, 15,
+ 11, 12, 13, 14, 15,
+ 11, 12, 13, 14, 15,
+ 11, 12, 13, 14, 15});
+}
+
+TEST_F(PadContextOpTest, Simple3Dim) {
+ TestPadContext(
+ {2, 3, 5},
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
+ 1, 2, {2, 6, 5},
+ {1, 2, 3, 4, 5,
+ 1, 2, 3, 4, 5,
+ 6, 7, 8, 9, 10,
+ 11, 12, 13, 14, 15,
+ 11, 12, 13, 14, 15,
+ 11, 12, 13, 14, 15,
+ 1, 2, 3, 4, 5,
+ 1, 2, 3, 4, 5,
+ 6, 7, 8, 9, 10,
+ 11, 12, 13, 14, 15,
+ 11, 12, 13, 14, 15,
+ 11, 12, 13, 14, 15});
+}
+
+} // namespace test
+} // namespace ops
+} // namespace mace
diff --git a/mace/ops/pnorm.cc b/mace/ops/pnorm.cc
index 8742a3b4492cb36aab4deece867a2021c4afd106..6964c6810bac50e59350410f009ac85c85f44ed6 100644
--- a/mace/ops/pnorm.cc
+++ b/mace/ops/pnorm.cc
@@ -48,8 +48,6 @@ class PNormOp : public Operation {
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
-
-
const std::vector &input_shape = input->shape();
const index_t dim_size = input_shape.size();
MACE_CHECK(dim_size >= 1, "PNorm only supports input dim size >= 1");
diff --git a/mace/ops/pnorm_benchmark.cc b/mace/ops/pnorm_benchmark.cc
index e3af765cd22f3589abd602dc6e28cd96acc2ee0f..e1efd0eb4052981188ae703f8044913c54afe671 100644
--- a/mace/ops/pnorm_benchmark.cc
+++ b/mace/ops/pnorm_benchmark.cc
@@ -57,7 +57,7 @@ void PNormBenchmark(int iters, int n, int h, int w, int p, int ow) {
int iters) { \
const int64_t tot = static_cast(iters) * N * H * W; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
- PNormBenchmark(iters, N, H, W, P, OW); \
+ PNormBenchmark(iters, N, H, W, P, OW); \
} \
MACE_BENCHMARK( \
MACE_BM_PNORM_##N##_##H##_##W##_##P##_##OW##_##TYPE##_##DEVICE)
diff --git a/mace/ops/slice.cc b/mace/ops/slice.cc
index f38a2a32a861a2ca20882268bc98d96fca55d6d7..f990912d0ce1f02ea65ab95d2334cf411aee2750 100644
--- a/mace/ops/slice.cc
+++ b/mace/ops/slice.cc
@@ -40,19 +40,18 @@ class SliceOp : public Operation {
const index_t rank = input->dim_size();
MACE_CHECK(rank >= 1)
<< "The input dim size should >= 1";
+ const index_t input_dim = input->dim(rank - 1);
MACE_CHECK(starts_.size() == 1 && ends_.size() == 1 && axes_.size() == 1,
"only support slicing at one axis.");
MACE_CHECK(axes_[0] == -1 || axes_[0] == rank - 1,
"only support slicing at the last axis.");
- const index_t input_dim = input->dim(rank - 1);
+ MACE_CHECK(starts_[0] < input_dim && starts_[0] >= 0
+ && ends_[0] >= 0
+ && ends_[0] <= input_dim)
+ << "The starts and ends caused over range error.";
const index_t offset = starts_[0];
const index_t output_dim = ends_[0] - starts_[0];
-
MACE_CHECK(output_dim >= 0, "output_dim should >= 0");
- MACE_CHECK(starts_[0] < input_dim
- && output_dim <= input_dim
- && ends_[0] <= input_dim)
- << "The starts and ends caused over range error.";
const index_t frames =
std::accumulate(input->shape().begin(), input->shape().end() - 1, 1,
diff --git a/mace/ops/slice_benchmark.cc b/mace/ops/slice_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fd2a383934964c37d132ec9fc4fe6dd731f490b6
--- /dev/null
+++ b/mace/ops/slice_benchmark.cc
@@ -0,0 +1,76 @@
+// Copyright 2018 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+namespace {
+template
+void BMSliceHelper(int iters,
+ const std::vector &input_shape,
+ const int offset,
+ const int output_dim) {
+ mace::testing::StopTiming();
+ // Construct graph
+ OpsTestNet net;
+ net.AddRandomInput("Input", input_shape);
+ OpDefBuilder("Slice", "SliceBM")
+ .Input("Input")
+ .Output("Output")
+ .AddIntArg("offset", offset)
+ .AddIntArg("output_dim", output_dim)
+ .AddIntArg("T", static_cast(DataTypeToEnum::value))
+ .Finalize(net.NewOperatorDef());
+
+ // Warm-up
+ for (int i = 0; i < 5; ++i) {
+ net.RunOp(D);
+ net.Sync();
+ }
+
+ mace::testing::StartTiming();
+ while (iters--) {
+ net.RunOp(D);
+ net.Sync();
+ }
+}
+} // namespace
+
+#define MACE_BM_SLICE_MACRO(N, H, W, S, D, TYPE, DEVICE) \
+ static void \
+ MACE_BM_SLICE_##N##_##H##_##W##_##S##_##D##_##TYPE##_##DEVICE( \
+ int iters) { \
+ const int64_t tot = static_cast(iters) * N * H * W; \
+ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
+ BMSliceHelper(iters, {N, H, W}, S, D); \
+ } \
+ MACE_BENCHMARK( \
+ MACE_BM_SLICE_##N##_##H##_##W##_##S##_##D##_##TYPE##_##DEVICE)
+
+#define MACE_BM_SLICE(N, H, W, S, D) \
+ MACE_BM_SLICE_MACRO(N, H, W, S, D, float, CPU);
+
+MACE_BM_SLICE(1, 32, 32, 5, 5);
+MACE_BM_SLICE(1, 32, 32, 7, 5);
+MACE_BM_SLICE(1, 32, 32, 3, 20);
+MACE_BM_SLICE(1, 128, 128, 9, 100);
+MACE_BM_SLICE(1, 128, 128, 7, 100);
+
+} // namespace test
+} // namespace ops
+} // namespace mace
diff --git a/mace/ops/softmax.cc b/mace/ops/softmax.cc
index cbab37adf5ebe9e0a3195483cecc287be5931bd0..427a29eb850c3a5577c4fd57a5b49e401e255b51 100644
--- a/mace/ops/softmax.cc
+++ b/mace/ops/softmax.cc
@@ -42,7 +42,8 @@ template <>
class SoftmaxOp : public Operation {
public:
explicit SoftmaxOp(OpConstructContext *context)
- : Operation(context) {}
+ : Operation(context),
+ use_log_(Operation::GetOptionalArg("use_log", false)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
@@ -88,9 +89,18 @@ class SoftmaxOp : public Operation {
sum = std::max(sum, std::numeric_limits::min());
channel_offset = 0;
- for (index_t c = 0; c < class_count; ++c) {
- output_ptr[channel_offset] /= sum;
- channel_offset += class_size;
+ if (use_log_) {
+ for (index_t c = 0; c < class_count; ++c) {
+ output_ptr[channel_offset] /= sum;
+ output_ptr[channel_offset] =
+ std::log(output_ptr[channel_offset]);
+ channel_offset += class_size;
+ }
+ } else {
+ for (index_t c = 0; c < class_count; ++c) {
+ output_ptr[channel_offset] /= sum;
+ channel_offset += class_size;
+ }
}
} // k
} // b
@@ -123,8 +133,15 @@ class SoftmaxOp : public Operation {
}
sum = std::max(sum, std::numeric_limits::min());
- for (index_t c = 0; c < class_count; ++c) {
- output_ptr[c] /= sum;
+ if (use_log_) {
+ for (index_t c = 0; c < class_count; ++c) {
+ output_ptr[c] /= sum;
+ output_ptr[c] = std::log(output_ptr[c]);
+ }
+ } else {
+ for (index_t c = 0; c < class_count; ++c) {
+ output_ptr[c] /= sum;
+ }
}
}
} else {
@@ -132,6 +149,9 @@ class SoftmaxOp : public Operation {
}
return MaceStatus::MACE_SUCCESS;
}
+
+ protected:
+ bool use_log_;
};
#ifdef MACE_ENABLE_QUANTIZE
@@ -142,10 +162,12 @@ template <>
class SoftmaxOp : public Operation {
public:
explicit SoftmaxOp(OpConstructContext *context)
- : Operation(context) {}
+ : Operation(context),
+ use_log_(Operation::GetOptionalArg("use_log", false)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
+ MACE_CHECK(!use_log_, "MACE dose not support quantized logsoftmax yet.");
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
@@ -366,6 +388,9 @@ class SoftmaxOp : public Operation {
}
return MaceStatus::MACE_SUCCESS;
}
+
+ protected:
+ bool use_log_;
};
#endif // MACE_ENABLE_QUANTIZE
@@ -375,11 +400,13 @@ class SoftmaxOp : public Operation {
public:
explicit SoftmaxOp(OpConstructContext *context)
: Operation(context) {
+ bool use_log = (
+ Operation::GetOptionalArg("use_log", false));
if (context->device()->gpu_runtime()->UseImageMemory()) {
- kernel_ = make_unique>();
+ kernel_ = make_unique>(use_log);
} else {
context->set_output_mem_type(MemoryType::GPU_BUFFER);
- kernel_ = make_unique>();
+ kernel_ = make_unique>(use_log);
}
}
MaceStatus Run(OpContext *context) override {
diff --git a/mace/ops/softmax_test.cc b/mace/ops/softmax_test.cc
index ee64820a0f4cb63e68ddb148f8df8c3a85bb332d..ab818ac8d55b5c0b277c41fb0044797666ee4bce 100644
--- a/mace/ops/softmax_test.cc
+++ b/mace/ops/softmax_test.cc
@@ -12,6 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// python implementation
+// import numpy as np
+// x = np.asarray([1., 1., 1., 1.], 'f')
+// exp_x = np.exp(x)
+// softmax_x = exp_x / np.sum(exp_x)
+// log_softmax_x = np.log(softmax_x)
+
#include "mace/ops/ops_test_util.h"
namespace mace {
@@ -19,18 +26,27 @@ namespace ops {
namespace test {
class SoftmaxOpTest : public OpsTestBase {};
+class LogSoftmaxOpTest : public OpsTestBase {};
namespace {
template
-void Simple() {
+void Simple(bool use_log = false) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray("Input", {1, 1, 2, 4},
{1, 1, 1, 1, 1, 2, 3, 4});
+
+ std::vector expected_data(8);
+ if (use_log) {
+ expected_data = {-1.3862944, -1.3862944, -1.3862944, -1.3862944,
+ -3.4401896 , -2.4401896 , -1.4401897 , -0.44018975};
+ } else {
+ expected_data = {0.25, 0.25, 0.25, 0.25,
+ 0.0320586, 0.08714432, 0.23688282, 0.6439142};
+ }
auto expected = net.CreateTensor(
- {1, 1, 2, 4},
- {0.25, 0.25, 0.25, 0.25, 0.0320586, 0.08714432, 0.23688282, 0.64391426});
+ {1, 1, 2, 4}, expected_data);
if (D == DeviceType::CPU) {
// test 4d softmax
@@ -38,6 +54,7 @@ void Simple() {
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputNCHW")
.Output("OutputNCHW")
+ .AddIntArg("use_log", static_cast(use_log))
.Finalize(net.NewOperatorDef());
// Run
@@ -52,6 +69,7 @@ void Simple() {
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("Input2d")
.Output("Output")
+ .AddIntArg("use_log", static_cast(use_log))
.Finalize(net.NewOperatorDef());
// Run
@@ -62,6 +80,7 @@ void Simple() {
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("Input")
.Output("Output")
+ .AddIntArg("use_log", static_cast(use_log))
.Finalize(net.NewOperatorDef());
// Run
@@ -77,9 +96,13 @@ void Simple() {
TEST_F(SoftmaxOpTest, CPUSimple) { Simple(); }
TEST_F(SoftmaxOpTest, OPENCLSimple) { Simple(); }
+TEST_F(LogSoftmaxOpTest, CPUSimple) { Simple(true); }
+TEST_F(LogSoftmaxOpTest, OPENCLSimple) { Simple(true); }
+
namespace {
template
-void Complex(const std::vector &logits_shape) {
+void Complex(const std::vector &logits_shape,
+ bool use_log = false) {
// Construct graph
OpsTestNet net;
// Add input data
@@ -91,11 +114,13 @@ void Complex(const std::vector &logits_shape) {
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputNCHW")
.Output("OutputNCHW")
+ .AddIntArg("use_log", static_cast(use_log))
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("Input")
.Output("Output")
+ .AddIntArg("use_log", static_cast(use_log))
.Finalize(net.NewOperatorDef());
}
// Run on cpu
@@ -111,6 +136,7 @@ void Complex(const std::vector &logits_shape) {
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("Input")
.Output("Output")
+ .AddIntArg("use_log", static_cast(use_log))
.Finalize(net.NewOperatorDef());
// Run on gpu
@@ -140,6 +166,26 @@ TEST_F(SoftmaxOpTest, OPENCLAlignedRank2) {
Complex({3, 1001});
}
+TEST_F(LogSoftmaxOpTest, OPENCLAligned) {
+Complex({1, 256, 256, 3}, true);
+Complex({1, 128, 128, 16}, true);
+}
+
+TEST_F(LogSoftmaxOpTest, OPENCLMulBatchAligned) {
+Complex({5, 64, 64, 3}, true);
+Complex({8, 128, 128, 8}, true);
+}
+
+TEST_F(LogSoftmaxOpTest, OPENCLUnAligned) {
+Complex({1, 113, 107, 13}, true);
+Complex({5, 211, 107, 1}, true);
+}
+
+TEST_F(LogSoftmaxOpTest, OPENCLAlignedRank2) {
+Complex({1, 1001}, true);
+Complex({3, 1001}, true);
+}
+
namespace {
void TestQuantizedSoftmax(const std::vector &input_shape) {
diff --git a/mace/ops/splice.cc b/mace/ops/splice.cc
index bf1dfe36b41e8c79675f3d75b0578fa2ce76816e..8517b6b831c80397086e0598f8803aeff0be81ce 100644
--- a/mace/ops/splice.cc
+++ b/mace/ops/splice.cc
@@ -14,16 +14,11 @@
// This Op is for SpliceComponent in Kaldi.
// It splices a context window of frames together [over time]
-// (copy and append the frame whose time-index in in context_)
+// (copy and append the frame whose time-index is in context_)
// The context_ values indicate which frame (over time) to splice.
-// if context value is less than the first time-index,
-// copy and append the first frame's dada,
-// when context value is larger than frame's count,
-// copy and append the last frame's data.
-// i.e., give input data: [[1, 2, 3], [4, 5, 6]],
-// with input-dim = 3, frame count = 2, context = [-1, 0, 1]
-// Then, the output should be:
-// [1, 2, 3, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 4, 5, 6]
+// It will reduce frames because of left context and right context.
+// i.e., give input data with shape {20, 40}, and contexts:{-2, -1, 0, 1, 2},
+// the output shape should be {16, 200}
// if const_component_dim_ != 0, const_dim_ will be used to determine which
// row of "in" we copy the last part of each row of "out" from (this part is
// not subject to splicing, it's assumed constant for each frame of "input".
@@ -54,24 +49,34 @@ class SpliceOp : public Operation {
const Tensor *input = this->Input(0);
MACE_CHECK(context_.size() > 0)
<< "The context param should not be empty in Splice Op.";
+ MACE_CHECK(input->dim_size() >= 2)
+ << "Splice's input's rank should be greater than 2.";
Tensor *output = this->Output(0);
const std::vector &input_shape = input->shape();
- const index_t frames =
- std::accumulate(input->shape().begin(), input->shape().end() - 1, 1,
+ const index_t batch =
+ std::accumulate(input->shape().begin(), input->shape().end() - 2, 1,
std::multiplies());
-
const index_t rank = input->dim_size();
+ const index_t chunk = input_shape[rank - 2];
const index_t input_dim = input_shape[rank - 1];
+ const index_t input_stride = chunk * input_dim;
const index_t num_splice = static_cast(context_.size());
const index_t dim = input_dim - const_dim_;
+ const index_t left_context = context_[0];
+ const index_t right_context = context_[num_splice -1];
+
+ const index_t out_chunk = chunk - (right_context - left_context);
+
MACE_CHECK(input_dim > const_dim_,
"input dim should be greater than const dim.");
const index_t output_dim = dim * num_splice + const_dim_;
+ const index_t output_stride = out_chunk * output_dim;
std::vector output_shape = input->shape();
+ output_shape[rank - 2] = out_chunk;
output_shape[rank - 1] = output_dim;
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
@@ -80,28 +85,32 @@ class SpliceOp : public Operation {
const T *input_data = input->data();
T *output_data = output->mutable_data();
-#pragma omp parallel for collapse(2) schedule(runtime)
- for (index_t i = 0; i < frames; ++i) {
+#pragma omp parallel for collapse(3) schedule(runtime)
+ for (int b = 0; b < batch; ++b) {
+ for (index_t i = 0; i < out_chunk; ++i) {
for (index_t c = 0; c < num_splice; ++c) {
- const index_t offset =
- Clamp(context_[c] + i, 0, frames - 1);
- T *output_base = output_data + i * output_dim + c * dim;
- const T *input_base = input_data + offset * input_dim;
+ const index_t offset = i + context_[c] - left_context;
+ T *output_base =
+ output_data + b * output_stride + i * output_dim + c * dim;
+ const T *input_base =
+ input_data + b * input_stride + offset * input_dim;
memcpy(output_base, input_base, dim * sizeof(T));
}
}
+ }
if (const_dim_ > 0) {
const index_t output_offset = output_dim - const_dim_;
const index_t input_offset = dim;
-#pragma omp parallel for schedule(runtime)
- for (index_t i = 0; i < frames; ++i) {
- index_t offset = i + context_[0] >= 0 ? i + context_[0] : 0;
- T *output_base = output_data + i * output_dim;
- const T *input_base = input_data + offset * input_dim;
+#pragma omp parallel for collapse(2) schedule(runtime)
+ for (int b = 0; b < batch; ++b) {
+ for (index_t i = 0; i < out_chunk; ++i) {
+ T *output_base = output_data + + b * output_stride + i * output_dim;
+ const T *input_base = input_data + b * input_stride + i * input_dim;
memcpy(output_base + output_offset,
input_base + input_offset,
const_dim_ * sizeof(T));
+ }
}
}
return MaceStatus::MACE_SUCCESS;
diff --git a/mace/ops/splice_benchmark.cc b/mace/ops/splice_benchmark.cc
index 253808b8385e1526432cfdc3cd5befd98f70736b..2a90493dcf2535e341db49fe848127bf70bd1929 100644
--- a/mace/ops/splice_benchmark.cc
+++ b/mace/ops/splice_benchmark.cc
@@ -23,15 +23,15 @@ namespace {
template
void BMSpliceHelper(int iters,
const std::vector &input_shape,
- const index_t left_context,
- const index_t right_context,
+ const int left_context,
+ const int right_context,
const int const_component_dim) {
mace::testing::StopTiming();
// Construct graph
OpsTestNet net;
- const int num_splice = left_context + right_context + 1;
+ const index_t num_splice = left_context + right_context + 1;
std::vector contexts(num_splice);
for (int i = 0; i < num_splice; ++i) {
contexts[i] = left_context + i;
@@ -44,7 +44,7 @@ void BMSpliceHelper(int iters,
GenerateRandomRealTypeData(input_shape, &input_data);
net.AddInputFromArray("Input", input_shape, input_data);
- OpDefBuilder("Splice", "SpliceTest")
+ OpDefBuilder("Splice", "SpliceBM")
.Input("Input")
.Output("Output")
.AddIntsArg("context", contexts)
@@ -71,7 +71,6 @@ void BMSpliceHelper(int iters,
MACE_BM_SPLICE_##N##_##H##_##W##_##L##_##R##_##C##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast(iters) * N * H * W; \
- mace::testing::MacsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
BMSpliceHelper(iters, {N, H, W}, L, R, C); \
} \
diff --git a/mace/ops/splice_test.cc b/mace/ops/splice_test.cc
index 60e1652a394d7d1a7b88c0b1f537ec5fc688d613..b6bc3d32c179a4475f0c58c921138be901ee8c2b 100644
--- a/mace/ops/splice_test.cc
+++ b/mace/ops/splice_test.cc
@@ -53,14 +53,10 @@ TEST_F(SpliceOpTest, WithoutConstDim) {
{1, 7, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14},
{-2, -1, 0, 1, 2}, 0,
- {1, 7, 10},
- {1, 2, 1, 2, 1, 2, 3, 4, 5, 6,
- 1, 2, 1, 2, 3, 4, 5, 6, 7, 8,
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
+ {1, 3, 10},
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
- 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
- 7, 8, 9, 10, 11, 12, 13, 14, 13, 14,
- 9, 10, 11, 12, 13, 14, 13, 14, 13, 14});
+ 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
}
TEST_F(SpliceOpTest, WithConstDim) {
@@ -72,12 +68,8 @@ TEST_F(SpliceOpTest, WithConstDim) {
4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
5, 6, 7, 8, 9, 10, 11, 12, 13, 14},
{-2, -1, 0, 1, 2}, 7,
- {1, 5, 22},
- {1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10,
- 1, 2, 3, 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 4, 5, 6, 7, 8, 9, 10,
- 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10,
- 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 5, 6, 7, 5, 6, 7, 8, 9, 10, 11,
- 3, 4, 5, 4, 5, 6, 5, 6, 7, 5, 6, 7, 5, 6, 7, 6, 7, 8, 9, 10, 11, 12});
+ {1, 1, 22},
+ {1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10});
}
} // namespace test
} // namespace ops
diff --git a/mace/ops/time_offset_benchmark.cc b/mace/ops/time_offset_benchmark.cc
deleted file mode 100644
index 82ea9967a9bd95542f012666593e81005cd64c48..0000000000000000000000000000000000000000
--- a/mace/ops/time_offset_benchmark.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright 2018 The MACE Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include
-#include
-
-#include "mace/core/testing/test_benchmark.h"
-#include "mace/ops/ops_test_util.h"
-
-namespace mace {
-namespace ops {
-namespace test {
-
-namespace {
-template
-void TimeOffsetBenchmark(int iters,
- std::vector shape,
- int offset) {
- mace::testing::StopTiming();
-
- OpsTestNet net;
-
- // Add input data
- net.AddRandomInput("Input", shape);
-
- OpDefBuilder("TimeOffset", "TimeOffsetBM")
- .Input("Input")
- .Output("Output")
- .AddIntArg("offset", offset)
- .Finalize(net.NewOperatorDef());
-
- // Warm-up
- for (int i = 0; i < 5; ++i) {
- net.RunOp(D);
- }
- net.Sync();
-
- mace::testing::StartTiming();
- while (iters--) {
- net.RunOp(D);
- }
- net.Sync();
-}
-} // namespace
-
-#define MACE_BM_TIMEOFFSET2D_MACRO(H, W, TYPE, DEVICE) \
- static void MACE_BM_TIMEOFFSET2D_##H##_##W##_##TYPE##_##DEVICE(\
- int iters) { \
- const int64_t tot = static_cast(iters) * H * W; \
- mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
- TimeOffsetBenchmark(iters, {H, W}, 1); \
- } \
- MACE_BENCHMARK(MACE_BM_TIMEOFFSET2D_##H##_##W##_##TYPE##_##DEVICE) \
-
-#define MACE_BM_TIMEOFFSET2D(H, W) \
- MACE_BM_TIMEOFFSET2D_MACRO(H, W, float, CPU);
-
-
-MACE_BM_TIMEOFFSET2D(20, 128);
-MACE_BM_TIMEOFFSET2D(40, 512);
-MACE_BM_TIMEOFFSET2D(1, 1024);
-MACE_BM_TIMEOFFSET2D(20, 2048);
-MACE_BM_TIMEOFFSET2D(20, 512);
-
-} // namespace test
-} // namespace ops
-} // namespace mace
diff --git a/mace/ops/time_offset_test.cc b/mace/ops/time_offset_test.cc
deleted file mode 100644
index b32b8c52acf3b8af715dac74f92d4a87efe1a102..0000000000000000000000000000000000000000
--- a/mace/ops/time_offset_test.cc
+++ /dev/null
@@ -1,125 +0,0 @@
-// Copyright 2018 The MACE Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "mace/ops/ops_test_util.h"
-
-namespace mace {
-namespace ops {
-namespace test {
-
-class TimeOffsetOpTest : public OpsTestBase {};
-
-namespace {
-template
-void TestTimeOffset(const std::vector &input_shape,
- const std::vector &input,
- const int offset,
- const std::vector &output) {
- OpsTestNet net;
- net.AddInputFromArray(MakeString("Input"),
- input_shape,
- input);
-
- OpDefBuilder("TimeOffset", "TimeOffsetTest")
- .Input("Input")
- .Output("Output")
- .AddIntArg("offset", offset)
- .Finalize(net.NewOperatorDef());
-
- net.RunOp();
-
- net.AddInputFromArray("ExpectedOutput", input_shape, output);
- ExpectTensorNear(*net.GetOutput("ExpectedOutput"),
- *net.GetOutput("Output"));
-}
-} // namespace
-
-TEST_F(TimeOffsetOpTest, Simple2Dim) {
- TestTimeOffset(
- {3, 5},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
- -2,
- {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
-
- TestTimeOffset(
- {3, 5},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
- -1,
- {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
-
- TestTimeOffset(
- {3, 5},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
- 0,
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
-
- TestTimeOffset(
- {3, 5},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
- 1,
- {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15});
-
- TestTimeOffset(
- {3, 5},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
- 2,
- {11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15});
-}
-
-
-TEST_F(TimeOffsetOpTest, Simple3Dim) {
- TestTimeOffset(
- {2, 3, 5},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
- -2,
- {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5,
- 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
-
- TestTimeOffset(
- {2, 3, 5},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
- -1,
- {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
- 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
-
- TestTimeOffset(
- {2, 3, 5},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
- 0,
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
-
- TestTimeOffset(
- {2, 3, 5},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
- 1,
- {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15,
- 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15});
-
- TestTimeOffset(
- {2, 3, 5},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
- 2,
- {11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15,
- 11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15});
-}
-
-} // namespace test
-} // namespace ops
-} // namespace mace
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index 7fc877d662a90bc4d6030daab3843b27cb801f80..bb1cb88a4f6ea9207e8defa058abce65140a11bc 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -102,7 +102,6 @@ class FrameworkType(Enum):
MaceSupportedOps = [
'Activation',
'AddN',
- 'Affine',
'ArgMax',
'BatchNorm',
'BatchToSpaceND',
@@ -126,10 +125,12 @@ MaceSupportedOps = [
'InferConv2dShape',
'LocalResponseNorm',
'LSTMCell',
- # 'LstmNonlinear',
+ 'LstmNonlinear',
+ 'DynamicLSTM',
'MatMul',
'OneHot',
'Pad',
+ 'PadContext',
'PNorm',
'Pooling',
'PriorBox',
@@ -156,7 +157,6 @@ MaceSupportedOps = [
'SqrDiffMean',
'SumGroup',
'TargetRMSNorm',
- 'TimeOffset',
'Transpose',
'WinogradInverseTransform',
'WinogradTransform',
diff --git a/mace/python/tools/converter_tool/onnx_converter.py b/mace/python/tools/converter_tool/onnx_converter.py
index 68f781a23dfc4fe5d09163b59422be15fec31f87..805cbd272731e1548a11dc23a284b305f2a2ee14 100644
--- a/mace/python/tools/converter_tool/onnx_converter.py
+++ b/mace/python/tools/converter_tool/onnx_converter.py
@@ -72,6 +72,7 @@ OnnxSupportedOps = [
'DimRange',
'Div',
'Dropout',
+ 'DynamicLstmCell',
'Elu',
'Equal',
# 'Exp',
@@ -90,16 +91,16 @@ OnnxSupportedOps = [
# 'Hardmax',
'Identity',
# 'If',
- 'IfDefined',
+ # 'IfDefined',
'ImageScaler',
# 'InstanceNormalization',
# 'LRN',
'LSTM',
- # 'LstmNonlinear',
+ 'LstmNonlinear',
'LeakyRelu',
# 'Less',
# 'Log',
- # 'LogSoftmax',
+ 'LogSoftmax',
# 'Loop',
# 'LpNormalization',
# 'LpPool',
@@ -120,6 +121,7 @@ OnnxSupportedOps = [
# 'Or',
'PRelu',
# 'Pad',
+ 'PadContext',
'Padding',
'PNorm',
'Pow',
@@ -331,12 +333,11 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.GlobalAveragePool.name: self.convert_reduce,
OnnxOpType.GlobalMaxPool.name: self.convert_reduce,
OnnxOpType.Identity.name: self.convert_identity,
- OnnxOpType.IfDefined.name: self.convert_identity,
OnnxOpType.ImageScaler.name: self.convert_imagescaler,
OnnxOpType.LeakyRelu.name: self.convert_activation,
- # OnnxOpType.LogSoftmax.name: self.convert_softmax,
- OnnxOpType.LSTM.name: self.convert_lstm,
- # OnnxOpType.LstmNonlinear.name: self.convert_lstm_nonlinear,
+ OnnxOpType.LogSoftmax.name: self.convert_softmax,
+ OnnxOpType.LstmNonlinear.name: self.convert_lstm_nonlinear,
+ OnnxOpType.DynamicLstmCell.name: self.convert_dynamic_lstm,
OnnxOpType.Max.name: self.convert_eltwise,
OnnxOpType.MaxPool.name: self.convert_pooling,
OnnxOpType.MatMul.name: self.convert_matmul,
@@ -344,7 +345,8 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Mul.name: self.convert_eltwise,
OnnxOpType.Neg.name: self.convert_eltwise,
OnnxOpType.Normalize: self.convert_normalize,
- OnnxOpType.Offset.name: self.convert_timeoffset,
+ OnnxOpType.Offset.name: self.convert_identity,
+ OnnxOpType.PadContext.name: self.convert_pad_context,
OnnxOpType.Padding.name: self.convert_identity,
OnnxOpType.PNorm.name: self.convert_pnorm,
OnnxOpType.Pow.name: self.convert_eltwise,
@@ -642,7 +644,7 @@ class OnnxConverter(base_converter.ConverterInterface):
mace_check(axis_value == 1 or axis_value == -3,
"only support concat at channel dimension")
elif node.op_type == OnnxOpType.Append.name:
- axis_value = 2
+ axis_value = 1
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = 4 + axis_value if axis_value < 0 else axis_value
@@ -758,14 +760,69 @@ class OnnxConverter(base_converter.ConverterInterface):
offset = node.attrs['offset']
starts_arg = op.arg.add()
starts_arg.name = 'starts'
- starts_arg.ints.append(offset)
+ starts_arg.ints.extend([offset])
output_dim = node.attrs['output_dim']
ends_arg = op.arg.add()
- ends_arg.name = 'output_dim'
- ends_arg.ints.append(output_dim)
+ ends_arg.name = 'ends'
+ ends_arg.ints.extend([output_dim + offset])
axes_arg = op.arg.add()
axes_arg.name = 'axes'
- axes_arg.ints.append(-1)
+ axes_arg.ints.extend([-1])
+
+ def convert_dynamic_lstm(self, node):
+ op = self.convert_general_op(node)
+ op.type = MaceOp.DynamicLSTM.name
+
+ if 'delay_a' in node.attrs:
+ prev_out_delay = node.attrs['delay_a']
+ mace_check(prev_out_delay < 0,
+ "dynamic's prev_out_delay should <= 0.")
+ prev_out_delay_arg = op.arg.add()
+ prev_out_delay_arg.name = 'prev_out_delay'
+ prev_out_delay_arg.i = prev_out_delay
+ if 'delay_b' in node.attrs:
+ prev_cell_delay = node.attrs['delay_b']
+ mace_check(prev_cell_delay < 0,
+ "dynamic's prev_cell_delay should < 0.")
+ prev_cell_delay_arg = op.arg.add()
+ prev_cell_delay_arg.name = 'prev_cell_delay'
+ prev_cell_delay_arg.i = prev_cell_delay
+ if 'prev_out_offset' in node.attrs:
+ prev_out_offset = node.attrs['prev_out_offset']
+ mace_check(pre_out_offset >= 0,
+ "dynamic's prev_out_offset should >= 0.")
+ prev_out_offset_arg = op.arg.add()
+ prev_out_offset_arg.name = 'prev_out_offset'
+ prev_out_offset_arg.i = prev_out_offset
+ if 'prev_a_dim' in node.attrs:
+ prev_out_dim = node.attrs['prev_a_dim']
+ mace_check(prev_out_dim > 0,
+ "dynamic's prev_out_dim should > 0.")
+ prev_out_dim_arg = op.arg.add()
+ prev_out_dim_arg.name = 'prev_out_dim'
+ prev_out_dim_arg.i = prev_out_dim
+ if 'prev_b_dim' in node.attrs:
+ prev_cell_dim = node.attrs['prev_b_dim']
+ mace_check(prev_cell_dim > 0,
+ "dynamic's prev_cell_dim should > 0.")
+ prev_cell_dim_arg = op.arg.add()
+ prev_cell_dim_arg.name = 'prev_cell_dim'
+ prev_cell_dim_arg.i = prev_cell_dim
+ if 'bias_a' in node.attrs:
+ bias_a = node.attrs['bias_a']
+ bias_a_arg = op.arg.add()
+ bias_a_arg.name = 'bias_a'
+ bias_a_arg.i = bias_a
+ if 'bias_b' in node.attrs:
+ bias_b = node.attrs['bias_b']
+ bias_b_arg = op.arg.add()
+ bias_b_arg.name = 'bias_b'
+ bias_b_arg.i = bias_b
+ if 'scale' in node.attrs:
+ scale = node.attrs['scale']
+ scale_arg = op.arg.add()
+ scale_arg.name = 'scale'
+ scale_arg.f = scale
def convert_eltwise(self, node):
op = self.convert_general_op(node)
@@ -925,6 +982,18 @@ class OnnxConverter(base_converter.ConverterInterface):
op = self.convert_general_op(node)
op.type = MaceOp.BatchNorm.name
+ def convert_pad_context(self, node):
+ op = self.convert_general_op(node)
+ op.type = MaceOp.PadContext.name
+ if 'left_context' in node.attrs:
+ left_context_arg = op.arg.add()
+ left_context_arg.name = 'left_context'
+ left_context_arg.i = node.attrs['left_context']
+ if 'right_context' in node.attrs:
+ right_context_arg = op.arg.add()
+ right_context_arg.name = 'right_context'
+ right_context_arg.i = node.attrs['right_context']
+
def convert_pnorm(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.PNorm.name
@@ -1010,10 +1079,10 @@ class OnnxConverter(base_converter.ConverterInterface):
op = self.convert_general_op(node)
op.type = MaceOp.Softmax.name
# TODO: add logsoftmax in softmax op
- # if node.op_type == OnnxOpType.LogSoftmax.name:
- # use_log_arg = op.arg.add()
- # use_log_arg.name = 'use_log'
- # use_log_arg.i = 1
+ if node.op_type == OnnxOpType.LogSoftmax.name:
+ use_log_arg = op.arg.add()
+ use_log_arg.name = 'use_log'
+ use_log_arg.i = 1
def convert_splice(self, node):
op = self.convert_general_op(node)
@@ -1104,6 +1173,11 @@ class OnnxConverter(base_converter.ConverterInterface):
else:
op.type = MaceOp.TimeOffset.name
+ chunk_size = node.attrs['chunk_size']
+ chunk_size_arg = op.arg.add()
+ chunk_size_arg.name = 'chunk_size'
+ chunk_size_arg.i = chunk_size
+
offset_arg = op.arg.add()
offset_arg.name = 'offset'
offset_arg.i = offset
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index 1083e23545767725e2f4e0d9c394d790fd5d0dd3..a097b10aba2691491ccf80a3d8952f4ff1500151 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -1143,6 +1143,7 @@ class Transformer(base_converter.ConverterInterface):
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
arg.i = 1
+ six.print_('transpose matmul weight')
def transpose_filters(self):
net = self._model
@@ -1192,7 +1193,6 @@ class Transformer(base_converter.ConverterInterface):
mace_check(filter_format == DataFormat.HWIO,
"HEXAGON only support HWIO/HWIM filter format.")
else:
- print("Transpose filters to OIHW/MIHW")
# transpose filter to OIHW/MIHW for tensorflow (HWIO/HWIM)
if filter_format == DataFormat.HWIO:
for op in net.op:
@@ -1201,6 +1201,7 @@ class Transformer(base_converter.ConverterInterface):
or op.type == MaceOp.DepthwiseConv2d.name) \
and op.input[1] in self._consts \
and op.input[1] not in transposed_filter:
+ print("Transpose Conv2D/Deconv2D filters to OIHW/MIHW")
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
@@ -1208,9 +1209,13 @@ class Transformer(base_converter.ConverterInterface):
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
transposed_filter.add(op.input[1])
- if (op.type == MaceOp.MatMul.name
- and (ConverterUtil.get_arg(op, MaceKeyword.mace_winograd_filter_transformed) is not None) # noqa
+ if (op.type == MaceOp.MatMul.name and
+ (ConverterUtil.get_arg(
+ op,
+ MaceKeyword.mace_winograd_filter_transformed)
+ is not None) # noqa
and op.input[1] not in transposed_filter):
+ print("Transpose Winograd filters to OIHW/MIHW")
filter = self._consts[op.input[0]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
@@ -1222,6 +1227,8 @@ class Transformer(base_converter.ConverterInterface):
and op.input[1] not in transposed_filter:
weight = self._consts[op.input[1]]
if len(weight.dims) == 4:
+ print("Transpose FullyConnected filters to"
+ " OIHW/MIHW")
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight_data = weight_data.transpose(3, 2, 0, 1)
diff --git a/mace/utils/math.h b/mace/utils/math.h
index 0293806c66667d55439b6802e1a8ec3943c1635e..2a806ab433757e3e3c48a2d2ea57dabaad50da32 100644
--- a/mace/utils/math.h
+++ b/mace/utils/math.h
@@ -60,25 +60,22 @@ inline Integer Clamp(Integer in, Integer low, Integer high) {
return std::max(low, std::min(in, high));
}
-template
-inline T ScalarSigmoid(T in) {
- if (in > static_cast(0)) {
- return static_cast(1) / (static_cast(1) + std::exp(-in));
+inline float ScalarSigmoid(float in) {
+ if (in > 0) {
+ return 1 / (1 + std::exp(-in));
} else {
- T x = std::exp(in);
- return x / (x + static_cast(1));
+ float x = std::exp(in);
+ return x / (x + 1.f);
}
}
-template
-inline T ScalarTanh(T in) {
- if (in > static_cast(0)) {
- T inv_expa = std::exp(-in);
- return -static_cast(1) +
- static_cast(2) / (static_cast(1) + inv_expa * inv_expa);
+inline float ScalarTanh(float in) {
+ if (in > 0) {
+ float x = std::exp(-in);
+ return -1.f + 2.f / (1.f + x * x);
} else {
- T x = std::exp(in);
- return x / (x + static_cast(1));
+ float x = std::exp(in);
+ return 1.f - 2.f / (1.f + x * x);
}
}