diff --git a/docs/development/dynamic_lstm.md b/docs/development/dynamic_lstm.md new file mode 100644 index 0000000000000000000000000000000000000000..f7d24a629c02263b42f19f3fb3004e3f4c5c2193 --- /dev/null +++ b/docs/development/dynamic_lstm.md @@ -0,0 +1,24 @@ +Dynamic LSTM +================== + + +The DynamicLSTM in MACE is implemented for Kaldi's time delay RNN models. + +The following pictures explain how to fuse components into a DynamicLSTMCell. + +Before fusing: + +
+![how to fuse lstm](imgs/FuseLSTM.png) +
+ + +After fusing: + +
+![DynamicLSTM](imgs/DynamicLSTM.png) +
+ + +For more details about LSTMNonlinear in Kaldi, +please refer to [LstmNonlinearComponent](http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164) \ No newline at end of file diff --git a/docs/development/imgs/DynamicLSTM.png b/docs/development/imgs/DynamicLSTM.png new file mode 100644 index 0000000000000000000000000000000000000000..98351988b559aed3690252719181a57101bf866f Binary files /dev/null and b/docs/development/imgs/DynamicLSTM.png differ diff --git a/docs/development/imgs/FuseLSTM.png b/docs/development/imgs/FuseLSTM.png new file mode 100644 index 0000000000000000000000000000000000000000..0cf89d62e71a3d9ff4bf0818556d18604372178c Binary files /dev/null and b/docs/development/imgs/FuseLSTM.png differ diff --git a/mace/benchmark/statistics.cc b/mace/benchmark/statistics.cc index 6ea69be0a67dec96c553ee199c8d42b8fe535ee8..9af7fcb30bf343f82232a5d240b5b78536d3949c 100644 --- a/mace/benchmark/statistics.cc +++ b/mace/benchmark/statistics.cc @@ -131,6 +131,9 @@ int64_t StatMACs(const std::string &op_type, output_shape.end(), 1, std::multiplies()); + } else if (op_type == "DynamicLSTM") { + macs = output_shape[0] * (filter_shape[0] * filter_shape[1] + + output_shape[1] * filter_shape[0] / 4); } return macs; } diff --git a/mace/ops/arm/fp32/gemv.cc b/mace/ops/arm/fp32/gemv.cc index cd0f607fd63f16bb5c99ea0a369dc8423a6bf358..7caa0b5b23d1a9b30d81ce94126bfc2a1a5b82d6 100644 --- a/mace/ops/arm/fp32/gemv.cc +++ b/mace/ops/arm/fp32/gemv.cc @@ -48,8 +48,8 @@ MaceStatus Gemv::Compute(const OpContext *context, Tensor *output) { MACE_UNUSED(context); - MACE_CHECK(output->size() == batch * lhs_height, - "Need resize output tensor before call gemv."); + MACE_CHECK(output->size() >= batch * lhs_height, + "Output buffer is not large enough for computing gemv."); Tensor::MappingGuard lhs_guard(lhs); Tensor::MappingGuard rhs_guard(rhs); diff --git a/mace/ops/common/lstm.cc b/mace/ops/common/lstm.cc new file mode 100644 index 
0000000000000000000000000000000000000000..beea3f5b8081584b219cd6c662c4451dfe4cc223 --- /dev/null +++ b/mace/ops/common/lstm.cc @@ -0,0 +1,75 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Details are in +// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164 + +#include "mace/ops/common/lstm.h" +#include "mace/utils/math.h" + +namespace mace { +namespace ops { + +void LSTMNonlinearKernel(const float *input_data, + const float *prev_data, + const float *scale_data, + const float *params_data, + bool embed_scales, + index_t params_stride, + index_t cell_dim, + float *output_cell, + float *output_data) { + float i_scale = (embed_scales && scale_data) ? scale_data[0] : 1.0f; + float f_scale = (embed_scales && scale_data) ? scale_data[1] : 1.0f; + float o_scale = (embed_scales && scale_data) ? 
scale_data[2] : 1.0f; + + if (prev_data == nullptr) { +#pragma omp parallel for schedule(runtime) + for (int c = 0; c < cell_dim; ++c) { + float i_part = input_data[c]; + float c_part = input_data[c + 2 * cell_dim]; + float o_part = input_data[c + 3 * cell_dim]; + float w_oc = params_data[c + params_stride * 2]; + float i_t = ScalarSigmoid(i_part); + float c_t = i_t * i_scale * std::tanh(c_part); + float o_t = ScalarSigmoid(o_part + w_oc * c_t); + float m_t = o_t * o_scale * std::tanh(c_t); + output_cell[c] = c_t; + output_data[c] = m_t; + } + } else { +#pragma omp parallel for schedule(runtime) + for (int c = 0; c < cell_dim; ++c) { + float i_part = input_data[c]; + float f_part = input_data[c + cell_dim]; + float c_part = input_data[c + 2 * cell_dim]; + float o_part = input_data[c + 3 * cell_dim]; + float c_prev = prev_data[c]; + float w_ic = params_data[c]; + float w_fc = params_data[c + params_stride]; + float w_oc = params_data[c + params_stride * 2]; + float i_t = ScalarSigmoid(i_part + w_ic * c_prev); + float f_t = ScalarSigmoid(f_part + w_fc * c_prev); + float c_t = + f_t * f_scale * c_prev + i_t * i_scale * std::tanh(c_part); + float o_t = ScalarSigmoid(o_part + w_oc * c_t); + float m_t = o_t * o_scale * std::tanh(c_t); + output_cell[c] = c_t; + output_data[c] = m_t; + } + } +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/common/lstm.h b/mace/ops/common/lstm.h new file mode 100644 index 0000000000000000000000000000000000000000..b835386041b6ba86f13818fe4f57c1efb1dff15d --- /dev/null +++ b/mace/ops/common/lstm.h @@ -0,0 +1,37 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_COMMON_LSTM_H_ +#define MACE_OPS_COMMON_LSTM_H_ + +#include "mace/core/types.h" +namespace mace { +namespace ops { + +void LSTMNonlinearKernel(const float *input_data, + const float *prev_data, + const float *scale_data, + const float *params_data, + bool embed_scales, + index_t params_stride, + index_t cell_dim, + float *output_cell, + float *output_data); + + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_COMMON_LSTM_H_ + diff --git a/mace/ops/dynamic_lstm.cc b/mace/ops/dynamic_lstm.cc new file mode 100644 index 0000000000000000000000000000000000000000..7fe93f21d6b7831bfe5fba3d21200a21923cdc2e --- /dev/null +++ b/mace/ops/dynamic_lstm.cc @@ -0,0 +1,332 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This Op is for Fused-LstmNonlinearComponent +// with prev cell states as inputs in Kaldi. 
+// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164 +// More details are in docs/development/dynamic_lstm.md + +#include +#include + +#include "mace/core/operator.h" +#include "mace/ops/common/lstm.h" + +#ifdef MACE_ENABLE_NEON +#include +#include "mace/ops/arm/fp32/gemv.h" +#else +#include "mace/ops/ref/gemv.h" +#endif // MACE_ENABLE_NEON + +namespace mace { +namespace ops { + +template +class DynamicLSTMOp; + +template +class DynamicLSTMOp : public Operation { + public: + explicit DynamicLSTMOp(OpConstructContext *context) + : Operation(context), + prev_out_delay_( + Operation::GetOptionalArg("prev_out_delay", 0)), + prev_cell_delay_( + Operation::GetOptionalArg("prev_cell_delay", 0)), + prev_out_offset_(Operation::GetOptionalArg("prev_out_offset", 0)), + prev_out_dim_(Operation::GetOptionalArg("prev_out_dim", 0)), + prev_cell_dim_(Operation::GetOptionalArg("prev_cell_dim", 0)), + has_bias_a_(Operation::GetOptionalArg("bias_a", 1)), + has_bias_b_(Operation::GetOptionalArg("bias_b", 1)), + scale_(Operation::GetOptionalArg("scale", 1.0f)) {} + + void UpdateCell(float *cell_data, + const index_t cell_dim, + const float scale) { + if (std::abs(scale - 1.f) < 1e-6) + return; + const index_t rounds = cell_dim / 4; +#pragma omp parallel for schedule(runtime) + for (index_t i = 0; i < rounds * 4; i += 4) { +#ifdef MACE_ENABLE_NEON + float32x4_t in_vec = vld1q_f32(cell_data + i); + float32x4_t scale_vec = vdupq_n_f32(scale); + in_vec = vmulq_f32(in_vec, scale_vec); + vst1q_f32(cell_data + i, in_vec); +#else + for (int j = 0; j < 4; ++j) { + cell_data[i + j] *= scale; + } +#endif + } + for (index_t i = rounds * 4; i < cell_dim; ++i) { + cell_data[i] *= scale; + } + } + + void CopyAndUpdateCell(float *src_data, + const index_t cell_dim, + const float scale, + float *cell_data) { + if (std::abs(scale - 1.f) < 1e-6) { + memcpy(cell_data, src_data, cell_dim * sizeof(float)); + return; + } + + const index_t rounds = cell_dim / 4; +#pragma omp parallel for 
schedule(runtime) + for (index_t i = 0; i < rounds * 4; i += 4) { +#ifdef MACE_ENABLE_NEON + float32x4_t in_vec = vld1q_f32(src_data + i); + float32x4_t scale_vec = vdupq_n_f32(scale); + in_vec = vmulq_f32(in_vec, scale_vec); + vst1q_f32(cell_data + i, in_vec); +#else + for (int j = 0; j < 4; ++j) { + cell_data[i + j] = src_data[i + j] * scale; + } +#endif + } + for (index_t i = rounds * 4; i < cell_dim; ++i) { + cell_data[i] = src_data[i] * scale; + } + } + + MaceStatus Run(OpContext *context) override { + MACE_UNUSED(context); + int max_input_num = 4; + MACE_CHECK(this->InputSize() >= max_input_num, + "DynamicLSTM has at least four inputs."); + MACE_CHECK(prev_cell_delay_ < 0 && prev_out_delay_ < 0); + MACE_CHECK(prev_out_dim_ > 0 && prev_cell_dim_ > 0); + const Tensor *input = this->Input(INPUT); + const Tensor *weights_a = this->Input(WEIGHTS_A); + const Tensor *lstm_params = this->Input(PARAMS); + const Tensor *weights_b = this->Input(WEIGHTS_B); + if (has_bias_a_) { + max_input_num++; + MACE_CHECK(this->InputSize() >= max_input_num, + "The first affine needs a bias input."); + } + const Tensor *bias_a = has_bias_a_ ? + this->Input(max_input_num - 1) : + nullptr; + if (has_bias_b_) { + max_input_num++; + MACE_CHECK(this->InputSize() >= max_input_num, + "The second affine needs a bias input."); + } + const Tensor *bias_b = has_bias_b_ ? 
+ this->Input(max_input_num - 1) : + nullptr; + + const index_t input_rank = input->dim_size(); + MACE_CHECK(input_rank >= 2, + "Dynamic LSTM Cell's input dim size should be >= 2."); + const std::vector &input_shape = input->shape(); + const index_t batch = + std::accumulate(input_shape.begin(), input_shape.end() - 2, 1, + std::multiplies()); + const index_t chunk = input_shape[input_rank - 2]; + const index_t input_dim = input_shape[input_rank - 1]; + + const index_t affine_a_in_dim = input_dim + prev_out_dim_; + const index_t affine_a_out_dim = weights_a->dim(0); + const index_t affine_a_depth = weights_a->dim(1); + MACE_CHECK(affine_a_in_dim == affine_a_depth) + << "affine_a's input_dim:" << affine_a_in_dim + << "!=" << "affine_a's weights' depth:" << affine_a_depth << std::endl; + + const index_t lstm_input_dim = affine_a_out_dim + prev_cell_dim_; + const index_t lstm_cell_dim = lstm_input_dim / 5; + const index_t params_stride = lstm_params->dim(1); + MACE_CHECK(lstm_input_dim == (lstm_cell_dim * 5)); + MACE_CHECK(lstm_params->dim(0) == 3 && + params_stride == lstm_cell_dim && lstm_cell_dim == prev_cell_dim_) + << "lstm params rows:" << lstm_params->dim(0) + << "params_stride:"<< params_stride + << "!=" << "cell_dim:"<< lstm_cell_dim << std::endl; + const index_t affine_b_out_dim = weights_b->dim(0); + const index_t affine_b_depth = weights_b->dim(1); + const index_t affine_b_in_dim = lstm_cell_dim; + MACE_CHECK(affine_b_in_dim == affine_b_depth) + << "affine_b's input_dim:" << affine_b_in_dim + << "!=" << "affine_b's weights' depth:" << affine_b_depth << std::endl; + + const index_t output_dim = affine_b_out_dim; + MACE_CHECK(prev_out_offset_ + prev_out_dim_ <= output_dim); + + const index_t affine_a_in_size = + PadAlignSize(affine_a_in_dim * sizeof(float)); + const index_t affine_a_out_size = + PadAlignSize(affine_a_out_dim * sizeof(float)); + const index_t affine_b_in_size = + PadAlignSize(affine_b_in_dim * sizeof(float)); + const index_t affine_b_out_size 
= + PadAlignSize(affine_b_out_dim * sizeof(float)); + + const int out_buf_chunk = abs(prev_out_delay_); + const int cell_buf_chunk = abs(prev_cell_delay_); + const index_t out_buf_size = + PadAlignSize(out_buf_chunk * prev_out_dim_ * sizeof(float)); + const index_t cell_buf_size = + PadAlignSize(cell_buf_chunk * prev_cell_dim_ * sizeof(float)); + ScratchBuffer *scratch = context->device()->scratch_buffer(); + scratch->Rewind(); + scratch->GrowSize(affine_a_in_size + affine_a_out_size + + affine_b_in_size + affine_b_out_size + + out_buf_size + cell_buf_size); + + Tensor prev_out(scratch->Scratch(out_buf_size), DT_FLOAT); + prev_out.Reshape({out_buf_chunk, prev_out_dim_}); + float *prev_out_data = prev_out.mutable_data(); + + Tensor prev_cell(scratch->Scratch(cell_buf_size), DT_FLOAT); + prev_cell.Reshape({cell_buf_chunk, prev_cell_dim_}); + float *prev_cell_data = prev_cell.mutable_data(); + + Tensor affine_a_in(scratch->Scratch(affine_a_in_size), DT_FLOAT); + affine_a_in.Reshape({1, affine_a_in_dim}); + float *affine_a_in_data = affine_a_in.mutable_data(); + + Tensor affine_a_out(scratch->Scratch(affine_a_out_size), DT_FLOAT); + affine_a_out.Reshape({1, affine_a_out_dim}); + float *affine_a_out_data = affine_a_out.mutable_data(); + + Tensor affine_b_in(scratch->Scratch(affine_b_in_size), DT_FLOAT); + affine_b_in.Reshape({1, affine_b_in_dim}); + float *affine_b_in_data = affine_b_in.mutable_data(); + + Tensor affine_b_out(scratch->Scratch(affine_b_out_size), DT_FLOAT); + affine_b_out.Reshape({1, affine_b_out_dim}); + float *affine_b_out_data = affine_b_out.mutable_data(); + + Tensor *output = this->Output(OUTPUT); + + std::vector output_shape = input->shape(); + output_shape[input_rank - 1] = output_dim; + + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard lstm_params_guard(lstm_params); + Tensor::MappingGuard output_guard(output); + const float *input_data = input->data(); + const float *lstm_params_data = 
lstm_params->data(); + float *output_data = output->mutable_data(); + + for (int b = 0; b < batch; ++b) { + int prev_out_idx = prev_out_delay_; + int prev_cell_idx = prev_cell_delay_; + prev_cell.Clear(); + prev_out.Clear(); + affine_a_in.Clear(); + affine_a_out.Clear(); + affine_b_in.Clear(); + affine_b_out.Clear(); + for (int i = 0; i < chunk; ++i) { + // Append + memcpy(affine_a_in_data, input_data, input_dim * sizeof(float)); + if (prev_out_idx >= 0) { + memcpy(affine_a_in_data + input_dim, + prev_out_data + prev_out_idx % out_buf_chunk * prev_out_dim_, + prev_out_dim_ * sizeof(float)); + } + // Affine + gemv_.Compute(context, + weights_a, + &affine_a_in, + bias_a, + 1, + affine_a_out_dim, + affine_a_depth, + false, + false, + &affine_a_out); + // Prepare LSTMNonlinear input and output pointer + float *prev_cell_ptr = + prev_cell_idx < 0 ? nullptr : + prev_cell_data + prev_cell_idx % cell_buf_chunk * prev_cell_dim_; + float *curr_cell_ptr = + prev_cell_data + i % cell_buf_chunk * prev_cell_dim_; + // LSTMNonlinear + LSTMNonlinearKernel(affine_a_out_data, + prev_cell_ptr, + nullptr, + lstm_params_data, + false, + params_stride, + lstm_cell_dim, + curr_cell_ptr, + affine_b_in_data); + UpdateCell(curr_cell_ptr, prev_cell_dim_, scale_); + // Affine + gemv_.Compute(context, + weights_b, + &affine_b_in, + bias_b, + 1, + affine_b_out_dim, + affine_b_depth, + false, + false, + &affine_b_out); + // Output + memcpy(output_data, + affine_b_out_data, + output_dim * sizeof(float)); + // Update + float *curr_out_ptr = prev_out_data + i % out_buf_chunk * prev_out_dim_; + CopyAndUpdateCell(affine_b_out_data + prev_out_offset_, + prev_out_dim_, + scale_, + curr_out_ptr); + input_data += input_dim; + output_data += output_dim; + prev_out_idx++; + prev_cell_idx++; + } + } + + return MaceStatus::MACE_SUCCESS; + } + + private: + int prev_out_delay_; + int prev_cell_delay_; + int prev_out_offset_; + int prev_out_dim_; + int prev_cell_dim_; + int has_bias_a_; + int has_bias_b_; + 
float scale_; + +#ifdef MACE_ENABLE_NEON + arm::fp32::Gemv gemv_; +#else + ref::Gemv gemv_; +#endif // MACE_ENABLE_NEON + + MACE_OP_INPUT_TAGS(INPUT, WEIGHTS_A, PARAMS, WEIGHTS_B); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +void RegisterDynamicLSTM(OpRegistryBase *op_registry) { + MACE_REGISTER_OP(op_registry, "DynamicLSTM", DynamicLSTMOp, + DeviceType::CPU, float); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/dynamic_lstm_benchmark.cc b/mace/ops/dynamic_lstm_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..00cd11bfd28fe4108847021dd11a40a237d15125 --- /dev/null +++ b/mace/ops/dynamic_lstm_benchmark.cc @@ -0,0 +1,124 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mace/benchmark/statistics.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/lstmcell_test_util.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void DynamicLSTM(int iters, + int chunk, + int input_dim, + int output_dim, + int cell_dim, + int prev_out_dim, + int delay) { + mace::testing::StopTiming(); + + OpsTestNet net; + MACE_CHECK(prev_out_dim <= output_dim); + const int weights_a_rows = 4 * cell_dim; + const int weights_a_cols = input_dim + prev_out_dim; + const int bias_a_rows = weights_a_rows; + + const int weights_b_rows = output_dim; + const int weights_b_cols = cell_dim; + const int bias_b_rows = weights_b_rows; + + // Add input data + net.AddRandomInput("Input", {chunk, input_dim}); + net.AddRandomInput("Weight_A", + {weights_a_rows, weights_a_cols}, + true); + net.AddRandomInput("Params", + {3, cell_dim}, + true); + net.AddRandomInput("Weight_B", + {weights_b_rows, weights_b_cols}, + true); + net.AddRandomInput("Bias_A", {bias_a_rows}, true); + net.AddRandomInput("Bias_B", {bias_b_rows}, true); + + if (D == DeviceType::CPU) { + OpDefBuilder("DynamicLSTM", "DynamicLSTMTest") + .Input("Input") + .Input("Weight_A") + .Input("Params") + .Input("Weight_B") + .Input("Bias_A") + .Input("Bias_B") + .Output("Output") + .AddIntArg("prev_out_delay", -delay) + .AddIntArg("prev_cell_delay", -delay) + .AddIntArg("prev_out_dim", prev_out_dim) + .AddIntArg("prev_cell_dim", cell_dim) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + } else { + MACE_NOT_IMPLEMENTED; + } + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} +} // namespace + +#define MACE_BM_DYNAMIC_LSTM_MACRO( \ + N, ID, OD, CD, POD, DELAY, TYPE, DEVICE) \ + static void \ + MACE_BM_DYNAMIC_LSTM_##N##_##ID##_##OD##_##CD##_##POD##_##DELAY##_##TYPE\ 
+ ##_##DEVICE( \ + int iters) { \ + int64_t wa_size = 4 * CD * (ID + POD); \ + int64_t wb_size = OD * CD; \ + int64_t prev_size = DELAY * (POD + CD); \ + int64_t in_out_size = N * (ID + OD); \ + int64_t bias_size = 4 * CD + OD; \ + const int64_t macs = static_cast(iters) * \ + mace::benchmark::StatMACs("DynamicLSTM", {4 * CD, ID + POD}, {N, OD});\ + const int64_t tot = static_cast(iters) * (in_out_size + prev_size\ + + wa_size + wb_size + bias_size); \ + mace::testing::MacsProcessed(macs); \ + mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \ + DynamicLSTM(iters, N, ID, OD, CD, POD, DELAY); \ + } \ + MACE_BENCHMARK( \ + MACE_BM_DYNAMIC_LSTM_##N##_##ID##_##OD##_##CD##_##POD##_##DELAY \ + ##_##TYPE##_##DEVICE) + +#define MACE_BM_DYNAMIC_LSTM(N, ID, OD, CD, POD, DELAY) \ + MACE_BM_DYNAMIC_LSTM_MACRO(N, ID, OD, CD, POD, DELAY, float, CPU); + +MACE_BM_DYNAMIC_LSTM(50, 184, 128, 184, 64, 3); +MACE_BM_DYNAMIC_LSTM(50, 64, 256, 64, 128, 3); +MACE_BM_DYNAMIC_LSTM(80, 64, 256, 128, 64, 3); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/lstm_nonlinear.cc b/mace/ops/lstm_nonlinear.cc new file mode 100644 index 0000000000000000000000000000000000000000..745c4d79674c6e2becc2eb49b2d855a2819a0e15 --- /dev/null +++ b/mace/ops/lstm_nonlinear.cc @@ -0,0 +1,113 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This Op is for LstmNonlinearComponent in Kaldi. 
+// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164 + +#include +#include + +#include "mace/core/operator.h" +#include "mace/ops/common/lstm.h" + +namespace mace { +namespace ops { + +template +class LSTMNonlinearOp; + +template +class LSTMNonlinearOp : public Operation { + public: + explicit LSTMNonlinearOp(OpConstructContext *context) + : Operation(context) {} + + MaceStatus Run(OpContext *context) override { + MACE_UNUSED(context); + const Tensor *input = this->Input(INPUT); + MACE_CHECK(this->InputSize() >= 2, + "LSTMNonlinear should have at least 2 inputs."); + const Tensor *params = this->Input(PARAMS); + Tensor *output = this->Output(OUTPUT); + + MACE_CHECK(input->dim_size() >= 2) + << "The input dim size should >= 2"; + MACE_CHECK(params->dim_size() == 2) + << "The params dim size should be 2"; + return Compute(input, params, output); + } + + MaceStatus Compute(const Tensor *input, + const Tensor *params, + Tensor *output) { + const std::vector &input_shape = input->shape(); + const std::vector ¶ms_shape = params->shape(); + + const index_t num_rows = + std::accumulate(input_shape.begin(), input_shape.end() - 1, 1, + std::multiplies()); + index_t rank = input->dim_size(); + const index_t input_cols = input_shape[rank - 1]; + const index_t cell_dim = input_cols / 5; + bool embed_scales = input_cols == cell_dim * 5 + 3; + const index_t params_stride = params_shape[1]; + + MACE_CHECK(input_cols == (cell_dim * 5) || embed_scales); + MACE_CHECK(params_shape[0] == 3 && params_shape[1] == cell_dim); + + const index_t output_dim = cell_dim * 2; + std::vector output_shape = input->shape(); + output_shape[rank - 1] = output_dim; + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard params_guard(params); + Tensor::MappingGuard output_guard(output); + const float *input_data = input->data(); + const float *params_data = params->data(); + float *output_data = output->mutable_data(); 
+#pragma omp parallel for schedule(runtime) + for (int r = 0; r < num_rows; ++r) { + const float *input_row = input_data + r * input_cols; + const float *prev_row = input_row + 4 * cell_dim; + const float *scale_data = + embed_scales ? prev_row + cell_dim : nullptr; + float *output_cell = output_data + r * output_dim; + float *output_row = output_cell + cell_dim; + LSTMNonlinearKernel(input_row, + prev_row, + scale_data, + params_data, + embed_scales, + params_stride, + cell_dim, + output_cell, + output_row); + } + + return MaceStatus::MACE_SUCCESS; + } + + protected: + MACE_OP_INPUT_TAGS(INPUT, PARAMS); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +void RegisterLSTMNonlinear(OpRegistryBase *op_registry) { + MACE_REGISTER_OP(op_registry, "LSTMNonlinear", LSTMNonlinearOp, + DeviceType::CPU, float); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/lstm_nonlinear_benchmark.cc b/mace/ops/lstm_nonlinear_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ca881c57f4fa768368a6977fbe89ac276dd83c1 --- /dev/null +++ b/mace/ops/lstm_nonlinear_benchmark.cc @@ -0,0 +1,85 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/lstmcell_test_util.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void LSTMNonlinear(int iters, + int batch, + int input_dim) { + mace::testing::StopTiming(); + + OpsTestNet net; + + int cell_dim = input_dim / 5; + + // Add input data + net.AddRandomInput("Input", {batch, input_dim}); + net.AddRandomInput("Params", + {3, cell_dim}, + true); + if (D == DeviceType::CPU) { + OpDefBuilder("LSTMNonlinear", "LSTMNonlinearTest") + .Input("Input") + .Input("Params") + .Output("Output") + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + } else { + MACE_NOT_IMPLEMENTED; + } + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} +} // namespace + +#define MACE_BM_LSTM_NONLIN_MACRO(N, IN_DIM, TYPE, DEVICE) \ + static void \ + MACE_BM_LSTM_NONLIN_##N##_##IN_DIM##_##TYPE##_##DEVICE(\ + int iters) { \ + const int64_t tot = \ + static_cast(iters) * (N * IN_DIM + 3 * (IN_DIM / 5));\ + mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \ + LSTMNonlinear(iters, N, IN_DIM); \ + } \ + MACE_BENCHMARK( \ + MACE_BM_LSTM_NONLIN_##N##_##IN_DIM##_##TYPE##_##DEVICE) + +#define MACE_BM_LSTM_NONLIN(N, IN_DIM) \ + MACE_BM_LSTM_NONLIN_MACRO(N, IN_DIM, float, CPU); + +MACE_BM_LSTM_NONLIN(50, 200); +MACE_BM_LSTM_NONLIN(50, 920); +MACE_BM_LSTM_NONLIN(80, 640); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/opencl/buffer/softmax.h b/mace/ops/opencl/buffer/softmax.h index db3f2800cefc276fda229cb7d27950f79c19fcea..3ab6a7cef1bd1d760ea70e1409f687d664f51996 100644 --- a/mace/ops/opencl/buffer/softmax.h +++ b/mace/ops/opencl/buffer/softmax.h @@ -32,12 +32,16 @@ namespace buffer { template class SoftmaxKernel : public OpenCLSoftmaxKernel { public: + explicit 
SoftmaxKernel(bool use_log) + : use_log_(use_log) {} + MaceStatus Compute( OpContext *context, const Tensor *logits, Tensor *output) override; private: + bool use_log_; cl::Kernel kernel_; uint32_t kwg_size_; std::vector input_shape_; @@ -88,6 +92,7 @@ MaceStatus SoftmaxKernel::Compute( built_options.emplace("-DIN_DATA_TYPE=" + DtToCLDt(logits->dtype())); built_options.emplace("-DOUT_DATA_TYPE=" + DtToCLDt(dt)); built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt)); + if (use_log_) built_options.emplace("-DUSE_LOG"); MACE_RETURN_IF_ERROR(runtime->BuildKernel("softmax_buffer", kernel_name, built_options, &kernel_)); diff --git a/mace/ops/opencl/cl/softmax.cl b/mace/ops/opencl/cl/softmax.cl index 39f8c89fe6e7b94bcfb803e1e6da1c562c86f320..c4babfd23be9ae7a23a12ee74c83c1b16d5dc8a5 100644 --- a/mace/ops/opencl/cl/softmax.cl +++ b/mace/ops/opencl/cl/softmax.cl @@ -73,13 +73,25 @@ __kernel void softmax(OUT_OF_RANGE_PARAMS switch(exceeded) { case 1: data.z = native_exp(data.z) / sum; +#ifdef USE_LOG + data.z = native_log(data.z); +#endif case 2: data.y = native_exp(data.y) / sum; +#ifdef USE_LOG + data.y = native_log(data.y); +#endif case 3: data.x = native_exp(data.x) / sum; +#ifdef USE_LOG + data.x = native_log(data.x); +#endif break; default: data = native_exp(data) / sum; +#ifdef USE_LOG + data = native_log(data); +#endif } WRITE_IMAGET(output, (int2)(pos, hb_idx), data); diff --git a/mace/ops/opencl/cl/softmax_buffer.cl b/mace/ops/opencl/cl/softmax_buffer.cl index 2a96a237d91c9d05fe7516527a331c681f62a3be..78fae30ff2e1173fb75c5eca5effee095d35e8b8 100644 --- a/mace/ops/opencl/cl/softmax_buffer.cl +++ b/mace/ops/opencl/cl/softmax_buffer.cl @@ -75,14 +75,26 @@ __kernel void softmax(BUFFER_OUT_OF_RANGE_PARAMS switch(remain_chan) { case 3: output[offset + 2] = native_exp(CONVERT(input[offset + 2]) - max_value) / sum; +#ifdef USE_LOG + output[offset + 2] = native_log(output[offset + 2]); +#endif case 2: output[offset + 1] = native_exp(CONVERT(input[offset + 1]) - 
max_value) / sum; +#ifdef USE_LOG + output[offset + 1] = native_log(output[offset + 1]); +#endif case 1: output[offset] = native_exp(CONVERT(input[offset]) - max_value) / sum; +#ifdef USE_LOG + output[offset] = native_log(output[offset]); +#endif } } else { data = CONVERT4(vload4(0, input + offset)); data = native_exp(data - max_value) / sum; +#ifdef USE_LOG + data = native_log(data); +#endif VSTORE4(CONVERT_TO(data, OUT_DATA_TYPE4), output, offset); } } diff --git a/mace/ops/opencl/image/softmax.h b/mace/ops/opencl/image/softmax.h index 2bbf1aa31d84d4d011bd5ab1875779f51ea236f3..3aa84bb5091066bff8565d3428fca7ebe4badafd 100644 --- a/mace/ops/opencl/image/softmax.h +++ b/mace/ops/opencl/image/softmax.h @@ -59,12 +59,16 @@ inline std::vector LocalWS(OpenCLRuntime *runtime, template class SoftmaxKernel : public OpenCLSoftmaxKernel { public: + explicit SoftmaxKernel(bool use_log) + : use_log_(use_log) {} + MaceStatus Compute( OpContext *context, const Tensor *logits, Tensor *output) override; private: + bool use_log_; cl::Kernel kernel_; uint32_t kwg_size_; std::vector input_shape_; @@ -114,6 +118,8 @@ MaceStatus SoftmaxKernel::Compute( auto dt = DataTypeToEnum::value; built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpCompatibleCLCMDDt(dt)); + if (use_log_) + built_options.emplace("-DUSE_LOG"); MACE_RETURN_IF_ERROR(runtime->BuildKernel("softmax", kernel_name, built_options, &kernel_)); diff --git a/mace/ops/ops_registry.cc b/mace/ops/ops_registry.cc index 7fc3545883ee855a578c001ac3ff75ff574261b6..26bf046391f6e3ad156a3366c75ddbd1e515c9f2 100644 --- a/mace/ops/ops_registry.cc +++ b/mace/ops/ops_registry.cc @@ -34,6 +34,7 @@ extern void RegisterDeconv2D(OpRegistryBase *op_registry); extern void RegisterDepthToSpace(OpRegistryBase *op_registry); extern void RegisterDepthwiseConv2d(OpRegistryBase *op_registry); extern void RegisterDepthwiseDeconv2d(OpRegistryBase *op_registry); +extern void 
RegisterDynamicLSTM(OpRegistryBase *op_registry); extern void RegisterEltwise(OpRegistryBase *op_registry); extern void RegisterExpandDims(OpRegistryBase *op_registry); extern void RegisterFill(OpRegistryBase *op_registry); @@ -42,9 +43,11 @@ extern void RegisterGather(OpRegistryBase *op_registry); extern void RegisterIdentity(OpRegistryBase *op_registry); extern void RegisterInferConv2dShape(OpRegistryBase *op_registry); extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry); +extern void RegisterLSTMNonlinear(OpRegistryBase *op_registry); extern void RegisterMatMul(OpRegistryBase *op_registry); extern void RegisterOneHot(OpRegistryBase *op_registry); extern void RegisterPad(OpRegistryBase *op_registry); +extern void RegisterPadContext(OpRegistryBase *op_registry); extern void RegisterPNorm(OpRegistryBase *op_registry); extern void RegisterPooling(OpRegistryBase *op_registry); extern void RegisterReduce(OpRegistryBase *op_registry); @@ -68,7 +71,6 @@ extern void RegisterStack(OpRegistryBase *op_registry); extern void RegisterStridedSlice(OpRegistryBase *op_registry); extern void RegisterSumGroup(OpRegistryBase *op_registry); extern void RegisterTargetRMSNorm(OpRegistryBase *op_registry); -extern void RegisterTimeOffset(OpRegistryBase *op_registry); extern void RegisterTranspose(OpRegistryBase *op_registry); extern void RegisterUnstack(OpRegistryBase *op_registry); @@ -102,6 +104,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() { ops::RegisterDepthToSpace(this); ops::RegisterDepthwiseConv2d(this); ops::RegisterDepthwiseDeconv2d(this); + ops::RegisterDynamicLSTM(this); ops::RegisterEltwise(this); ops::RegisterExpandDims(this); ops::RegisterFill(this); @@ -110,9 +113,11 @@ OpRegistry::OpRegistry() : OpRegistryBase() { ops::RegisterIdentity(this); ops::RegisterInferConv2dShape(this); ops::RegisterLocalResponseNorm(this); + ops::RegisterLSTMNonlinear(this); ops::RegisterMatMul(this); ops::RegisterOneHot(this); ops::RegisterPad(this); + 
ops::RegisterPadContext(this); ops::RegisterPNorm(this); ops::RegisterPooling(this); ops::RegisterReduce(this); @@ -136,7 +141,6 @@ OpRegistry::OpRegistry() : OpRegistryBase() { ops::RegisterSqueeze(this); ops::RegisterSumGroup(this); ops::RegisterTargetRMSNorm(this); - ops::RegisterTimeOffset(this); ops::RegisterTranspose(this); ops::RegisterUnstack(this); diff --git a/mace/ops/time_offset.cc b/mace/ops/pad_context.cc similarity index 53% rename from mace/ops/time_offset.cc rename to mace/ops/pad_context.cc index d9343fc327438a965fe4b3e98a583783a6d4993a..6c463ec9830b2e22e234cef6e4ec7eddc61d9906 100644 --- a/mace/ops/time_offset.cc +++ b/mace/ops/pad_context.cc @@ -25,14 +25,15 @@ namespace mace { namespace ops { template -class TimeOffsetOp; +class PadContextOp; template -class TimeOffsetOp : public Operation { +class PadContextOp : public Operation { public: - explicit TimeOffsetOp(OpConstructContext *context) + explicit PadContextOp(OpConstructContext *context) : Operation(context), - offset_(Operation::GetOptionalArg("offset", 0)) {} + left_context_(Operation::GetOptionalArg("left_context", 0)), + right_context_(Operation::GetOptionalArg("right_context", 0)) {} MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); @@ -41,27 +42,38 @@ class TimeOffsetOp : public Operation { index_t rank = input->dim_size(); MACE_CHECK(rank >= 2, "input's rank should >= 2."); + MACE_CHECK(left_context_ > 0 && right_context_ > 0, + "left context and right context should be greater than zero"); const std::vector &input_shape = input->shape(); const index_t batch = std::accumulate(input_shape.begin(), input_shape.end() - 2, 1, std::multiplies()); - const index_t frames = input_shape[rank - 2]; - const index_t input_dim = input_shape[rank - 1]; - MACE_RETURN_IF_ERROR(output->ResizeLike(input)); + const index_t chunk = input_shape[rank - 2]; + const index_t dim = input_shape[rank - 1]; + const index_t output_chunk = chunk + left_context_ + right_context_; + std::vector 
output_shape = input->shape(); + output_shape[rank - 2] = output_chunk; + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); Tensor::MappingGuard input_guard(input); Tensor::MappingGuard output_guard(output); const T *input_data = input->data(); T *output_data = output->mutable_data(); -#pragma omp parallel for collapse(2) schedule(runtime) for (index_t i = 0; i < batch; ++i) { - for (index_t j = 0; j < frames; ++j) { - index_t time_index = offset_ + j; - index_t index = Clamp(time_index, 0, frames - 1); - T *output_base = output_data + (i * frames + j) * input_dim; - const T *input_base = input_data + (i * frames + index) * input_dim; - memcpy(output_base, input_base, input_dim * sizeof(T)); + T *out_base = output_data + i * output_chunk * dim; + const T *in_base = input_data + i * chunk * dim; +#pragma omp parallel for schedule(runtime) + for (index_t j = 0; j < left_context_; ++j) { + memcpy(out_base + j * dim, in_base, dim * sizeof(T)); + } + out_base = out_base + left_context_ * dim; + memcpy(out_base, in_base, chunk * dim * sizeof(T)); + out_base = out_base + chunk * dim; + in_base = in_base + (chunk -1) * dim; +#pragma omp parallel for schedule(runtime) + for (index_t j = 0; j < right_context_; ++j) { + memcpy(out_base + j * dim, in_base, dim * sizeof(T)); } } @@ -69,11 +81,12 @@ class TimeOffsetOp : public Operation { } private: - int offset_; + int left_context_; + int right_context_; }; -void RegisterTimeOffset(OpRegistryBase *op_registry) { - MACE_REGISTER_OP(op_registry, "TimeOffset", TimeOffsetOp, +void RegisterPadContext(OpRegistryBase *op_registry) { + MACE_REGISTER_OP(op_registry, "PadContext", PadContextOp, DeviceType::CPU, float); } diff --git a/mace/ops/pad_context_benchmark.cc b/mace/ops/pad_context_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..3384f79d9f7aef9d18785cf89511de64f5ca609d --- /dev/null +++ b/mace/ops/pad_context_benchmark.cc @@ -0,0 +1,79 @@ +// Copyright 2018 The MACE Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void PadContextBM(int iters, + const std::vector &input_shape, + const int left_context, + const int right_context) { + mace::testing::StopTiming(); + + // Construct graph + OpsTestNet net; + + net.AddRandomInput("Input", input_shape); + + OpDefBuilder("PadContext", "PadContextBM") + .Input("Input") + .Output("Output") + .AddIntArg("left_context", left_context) + .AddIntArg("right_context", right_context) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + net.Sync(); + } + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + net.Sync(); + } +} +} // namespace + +#define MACE_BM_PAD_CONTEXT_MACRO(N, H, W, L, R, TYPE, DEVICE) \ + static void \ + MACE_BM_PAD_CONTEXT_##N##_##H##_##W##_##L##_##R##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * H * W; \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + PadContextBM(iters, {N, H, W}, L, R); \ + } \ + MACE_BENCHMARK( \ + MACE_BM_PAD_CONTEXT_##N##_##H##_##W##_##L##_##R##_##TYPE##_##DEVICE) + +#define MACE_BM_PAD_CONTEXT(N, H, W, L, R) \ + MACE_BM_PAD_CONTEXT_MACRO(N, H, W, L, R, float, CPU); + +MACE_BM_PAD_CONTEXT(1, 32, 
32, 5, 5); +MACE_BM_PAD_CONTEXT(2, 32, 32, 7, 7); +MACE_BM_PAD_CONTEXT(1, 32, 32, 3, 3); +MACE_BM_PAD_CONTEXT(1, 128, 128, 9, 9); +MACE_BM_PAD_CONTEXT(3, 128, 128, 7, 7); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/pad_context_test.cc b/mace/ops/pad_context_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..72f7ba576c3e8b655a3da61302147ab91aeb0fdf --- /dev/null +++ b/mace/ops/pad_context_test.cc @@ -0,0 +1,87 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class PadContextOpTest : public OpsTestBase {}; + +namespace { +template +void TestPadContext(const std::vector &input_shape, + const std::vector &input, + const int left_context, + const int right_context, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + net.AddInputFromArray(MakeString("Input"), + input_shape, + input); + + OpDefBuilder("PadContext", "PadContextTest") + .Input("Input") + .Output("Output") + .AddIntArg("left_context", left_context) + .AddIntArg("right_context", right_context) + .Finalize(net.NewOperatorDef()); + + net.RunOp(); + + auto expected = net.CreateTensor(output_shape, output); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} +} // namespace + +TEST_F(PadContextOpTest, Simple2Dim) { + TestPadContext( + {3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + 2, 3, {8, 5}, + {1, 2, 3, 4, 5, + 1, 2, 3, 4, 5, + 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, + 11, 12, 13, 14, 15, + 11, 12, 13, 14, 15, + 11, 12, 13, 14, 15}); +} + +TEST_F(PadContextOpTest, Simple3Dim) { + TestPadContext( + {2, 3, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + 1, 2, {2, 6, 5}, + {1, 2, 3, 4, 5, + 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, + 11, 12, 13, 14, 15, + 11, 12, 13, 14, 15, + 1, 2, 3, 4, 5, + 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, + 11, 12, 13, 14, 15, + 11, 12, 13, 14, 15}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/pnorm.cc b/mace/ops/pnorm.cc index 8742a3b4492cb36aab4deece867a2021c4afd106..6964c6810bac50e59350410f009ac85c85f44ed6 100644 --- a/mace/ops/pnorm.cc +++ b/mace/ops/pnorm.cc @@ -48,8 +48,6 @@ class PNormOp : public Operation { const Tensor *input = this->Input(0); Tensor *output = this->Output(0); - - const std::vector &input_shape = 
input->shape(); const index_t dim_size = input_shape.size(); MACE_CHECK(dim_size >= 1, "PNorm only supports input dim size >= 1"); diff --git a/mace/ops/pnorm_benchmark.cc b/mace/ops/pnorm_benchmark.cc index e3af765cd22f3589abd602dc6e28cd96acc2ee0f..e1efd0eb4052981188ae703f8044913c54afe671 100644 --- a/mace/ops/pnorm_benchmark.cc +++ b/mace/ops/pnorm_benchmark.cc @@ -57,7 +57,7 @@ void PNormBenchmark(int iters, int n, int h, int w, int p, int ow) { int iters) { \ const int64_t tot = static_cast(iters) * N * H * W; \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - PNormBenchmark(iters, N, H, W, P, OW); \ + PNormBenchmark(iters, N, H, W, P, OW); \ } \ MACE_BENCHMARK( \ MACE_BM_PNORM_##N##_##H##_##W##_##P##_##OW##_##TYPE##_##DEVICE) diff --git a/mace/ops/slice.cc b/mace/ops/slice.cc index f38a2a32a861a2ca20882268bc98d96fca55d6d7..f990912d0ce1f02ea65ab95d2334cf411aee2750 100644 --- a/mace/ops/slice.cc +++ b/mace/ops/slice.cc @@ -40,19 +40,18 @@ class SliceOp : public Operation { const index_t rank = input->dim_size(); MACE_CHECK(rank >= 1) << "The input dim size should >= 1"; + const index_t input_dim = input->dim(rank - 1); MACE_CHECK(starts_.size() == 1 && ends_.size() == 1 && axes_.size() == 1, "only support slicing at one axis."); MACE_CHECK(axes_[0] == -1 || axes_[0] == rank - 1, "only support slicing at the last axis."); - const index_t input_dim = input->dim(rank - 1); + MACE_CHECK(starts_[0] < input_dim && starts_[0] >= 0 + && ends_[0] >= 0 + && ends_[0] <= input_dim) + << "The starts and ends caused over range error."; const index_t offset = starts_[0]; const index_t output_dim = ends_[0] - starts_[0]; - MACE_CHECK(output_dim >= 0, "output_dim should >= 0"); - MACE_CHECK(starts_[0] < input_dim - && output_dim <= input_dim - && ends_[0] <= input_dim) - << "The starts and ends caused over range error."; const index_t frames = std::accumulate(input->shape().begin(), input->shape().end() - 1, 1, diff --git a/mace/ops/slice_benchmark.cc 
b/mace/ops/slice_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..fd2a383934964c37d132ec9fc4fe6dd731f490b6 --- /dev/null +++ b/mace/ops/slice_benchmark.cc @@ -0,0 +1,76 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void BMSliceHelper(int iters, + const std::vector &input_shape, + const int offset, + const int output_dim) { + mace::testing::StopTiming(); + // Construct graph + OpsTestNet net; + net.AddRandomInput("Input", input_shape); + OpDefBuilder("Slice", "SliceBM") + .Input("Input") + .Output("Output") + .AddIntArg("offset", offset) + .AddIntArg("output_dim", output_dim) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + net.Sync(); + } + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + net.Sync(); + } +} +} // namespace + +#define MACE_BM_SLICE_MACRO(N, H, W, S, D, TYPE, DEVICE) \ + static void \ + MACE_BM_SLICE_##N##_##H##_##W##_##S##_##D##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * H * W; \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + BMSliceHelper(iters, {N, H, W}, S, D); \ + } \ + MACE_BENCHMARK( \ + 
MACE_BM_SLICE_##N##_##H##_##W##_##S##_##D##_##TYPE##_##DEVICE) + +#define MACE_BM_SLICE(N, H, W, S, D) \ + MACE_BM_SLICE_MACRO(N, H, W, S, D, float, CPU); + +MACE_BM_SLICE(1, 32, 32, 5, 5); +MACE_BM_SLICE(1, 32, 32, 7, 5); +MACE_BM_SLICE(1, 32, 32, 3, 20); +MACE_BM_SLICE(1, 128, 128, 9, 100); +MACE_BM_SLICE(1, 128, 128, 7, 100); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/softmax.cc b/mace/ops/softmax.cc index cbab37adf5ebe9e0a3195483cecc287be5931bd0..427a29eb850c3a5577c4fd57a5b49e401e255b51 100644 --- a/mace/ops/softmax.cc +++ b/mace/ops/softmax.cc @@ -42,7 +42,8 @@ template <> class SoftmaxOp : public Operation { public: explicit SoftmaxOp(OpConstructContext *context) - : Operation(context) {} + : Operation(context), + use_log_(Operation::GetOptionalArg("use_log", false)) {} MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); @@ -88,9 +89,18 @@ class SoftmaxOp : public Operation { sum = std::max(sum, std::numeric_limits::min()); channel_offset = 0; - for (index_t c = 0; c < class_count; ++c) { - output_ptr[channel_offset] /= sum; - channel_offset += class_size; + if (use_log_) { + for (index_t c = 0; c < class_count; ++c) { + output_ptr[channel_offset] /= sum; + output_ptr[channel_offset] = + std::log(output_ptr[channel_offset]); + channel_offset += class_size; + } + } else { + for (index_t c = 0; c < class_count; ++c) { + output_ptr[channel_offset] /= sum; + channel_offset += class_size; + } } } // k } // b @@ -123,8 +133,15 @@ class SoftmaxOp : public Operation { } sum = std::max(sum, std::numeric_limits::min()); - for (index_t c = 0; c < class_count; ++c) { - output_ptr[c] /= sum; + if (use_log_) { + for (index_t c = 0; c < class_count; ++c) { + output_ptr[c] /= sum; + output_ptr[c] = std::log(output_ptr[c]); + } + } else { + for (index_t c = 0; c < class_count; ++c) { + output_ptr[c] /= sum; + } } } } else { @@ -132,6 +149,9 @@ class SoftmaxOp : public Operation { } return MaceStatus::MACE_SUCCESS; } + 
+ protected: + bool use_log_; }; #ifdef MACE_ENABLE_QUANTIZE @@ -142,10 +162,12 @@ template <> class SoftmaxOp : public Operation { public: explicit SoftmaxOp(OpConstructContext *context) - : Operation(context) {} + : Operation(context), + use_log_(Operation::GetOptionalArg("use_log", false)) {} MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); + MACE_CHECK(!use_log_, "MACE dose not support quantized logsoftmax yet."); const Tensor *input = this->Input(0); Tensor *output = this->Output(0); MACE_RETURN_IF_ERROR(output->ResizeLike(input)); @@ -366,6 +388,9 @@ class SoftmaxOp : public Operation { } return MaceStatus::MACE_SUCCESS; } + + protected: + bool use_log_; }; #endif // MACE_ENABLE_QUANTIZE @@ -375,11 +400,13 @@ class SoftmaxOp : public Operation { public: explicit SoftmaxOp(OpConstructContext *context) : Operation(context) { + bool use_log = ( + Operation::GetOptionalArg("use_log", false)); if (context->device()->gpu_runtime()->UseImageMemory()) { - kernel_ = make_unique>(); + kernel_ = make_unique>(use_log); } else { context->set_output_mem_type(MemoryType::GPU_BUFFER); - kernel_ = make_unique>(); + kernel_ = make_unique>(use_log); } } MaceStatus Run(OpContext *context) override { diff --git a/mace/ops/softmax_test.cc b/mace/ops/softmax_test.cc index ee64820a0f4cb63e68ddb148f8df8c3a85bb332d..ab818ac8d55b5c0b277c41fb0044797666ee4bce 100644 --- a/mace/ops/softmax_test.cc +++ b/mace/ops/softmax_test.cc @@ -12,6 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+// python implementation +// import numpy as np +// x = np.asarray([1., 1., 1., 1.], 'f') +// exp_x = np.exp(x) +// softmax_x = exp_x / np.sum(exp_x) +// log_softmax_x = np.log(softmax_x) + #include "mace/ops/ops_test_util.h" namespace mace { @@ -19,18 +26,27 @@ namespace ops { namespace test { class SoftmaxOpTest : public OpsTestBase {}; +class LogSoftmaxOpTest : public OpsTestBase {}; namespace { template -void Simple() { +void Simple(bool use_log = false) { // Construct graph OpsTestNet net; // Add input data net.AddInputFromArray("Input", {1, 1, 2, 4}, {1, 1, 1, 1, 1, 2, 3, 4}); + + std::vector expected_data(8); + if (use_log) { + expected_data = {-1.3862944, -1.3862944, -1.3862944, -1.3862944, + -3.4401896 , -2.4401896 , -1.4401897 , -0.44018975}; + } else { + expected_data = {0.25, 0.25, 0.25, 0.25, + 0.0320586, 0.08714432, 0.23688282, 0.6439142}; + } auto expected = net.CreateTensor( - {1, 1, 2, 4}, - {0.25, 0.25, 0.25, 0.25, 0.0320586, 0.08714432, 0.23688282, 0.64391426}); + {1, 1, 2, 4}, expected_data); if (D == DeviceType::CPU) { // test 4d softmax @@ -38,6 +54,7 @@ void Simple() { OpDefBuilder("Softmax", "SoftmaxTest") .Input("InputNCHW") .Output("OutputNCHW") + .AddIntArg("use_log", static_cast(use_log)) .Finalize(net.NewOperatorDef()); // Run @@ -52,6 +69,7 @@ void Simple() { OpDefBuilder("Softmax", "SoftmaxTest") .Input("Input2d") .Output("Output") + .AddIntArg("use_log", static_cast(use_log)) .Finalize(net.NewOperatorDef()); // Run @@ -62,6 +80,7 @@ void Simple() { OpDefBuilder("Softmax", "SoftmaxTest") .Input("Input") .Output("Output") + .AddIntArg("use_log", static_cast(use_log)) .Finalize(net.NewOperatorDef()); // Run @@ -77,9 +96,13 @@ void Simple() { TEST_F(SoftmaxOpTest, CPUSimple) { Simple(); } TEST_F(SoftmaxOpTest, OPENCLSimple) { Simple(); } +TEST_F(LogSoftmaxOpTest, CPUSimple) { Simple(true); } +TEST_F(LogSoftmaxOpTest, OPENCLSimple) { Simple(true); } + namespace { template -void Complex(const std::vector &logits_shape) { +void 
Complex(const std::vector &logits_shape, + bool use_log = false) { // Construct graph OpsTestNet net; // Add input data @@ -91,11 +114,13 @@ void Complex(const std::vector &logits_shape) { OpDefBuilder("Softmax", "SoftmaxTest") .Input("InputNCHW") .Output("OutputNCHW") + .AddIntArg("use_log", static_cast(use_log)) .Finalize(net.NewOperatorDef()); } else { OpDefBuilder("Softmax", "SoftmaxTest") .Input("Input") .Output("Output") + .AddIntArg("use_log", static_cast(use_log)) .Finalize(net.NewOperatorDef()); } // Run on cpu @@ -111,6 +136,7 @@ void Complex(const std::vector &logits_shape) { OpDefBuilder("Softmax", "SoftmaxTest") .Input("Input") .Output("Output") + .AddIntArg("use_log", static_cast(use_log)) .Finalize(net.NewOperatorDef()); // Run on gpu @@ -140,6 +166,26 @@ TEST_F(SoftmaxOpTest, OPENCLAlignedRank2) { Complex({3, 1001}); } +TEST_F(LogSoftmaxOpTest, OPENCLAligned) { +Complex({1, 256, 256, 3}, true); +Complex({1, 128, 128, 16}, true); +} + +TEST_F(LogSoftmaxOpTest, OPENCLMulBatchAligned) { +Complex({5, 64, 64, 3}, true); +Complex({8, 128, 128, 8}, true); +} + +TEST_F(LogSoftmaxOpTest, OPENCLUnAligned) { +Complex({1, 113, 107, 13}, true); +Complex({5, 211, 107, 1}, true); +} + +TEST_F(LogSoftmaxOpTest, OPENCLAlignedRank2) { +Complex({1, 1001}, true); +Complex({3, 1001}, true); +} + namespace { void TestQuantizedSoftmax(const std::vector &input_shape) { diff --git a/mace/ops/splice.cc b/mace/ops/splice.cc index bf1dfe36b41e8c79675f3d75b0578fa2ce76816e..8517b6b831c80397086e0598f8803aeff0be81ce 100644 --- a/mace/ops/splice.cc +++ b/mace/ops/splice.cc @@ -14,16 +14,11 @@ // This Op is for SpliceComponent in Kaldi. // It splices a context window of frames together [over time] -// (copy and append the frame whose time-index in in context_) +// (copy and append the frame whose time-index is in context_) // The context_ values indicate which frame (over time) to splice. 
-// if context value is less than the first time-index, -// copy and append the first frame's dada, -// when context value is larger than frame's count, -// copy and append the last frame's data. -// i.e., give input data: [[1, 2, 3], [4, 5, 6]], -// with input-dim = 3, frame count = 2, context = [-1, 0, 1] -// Then, the output should be: -// [1, 2, 3, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 4, 5, 6] +// The number of frames is reduced by the span of the left and right context. +// e.g., given input data with shape {20, 40} and contexts {-2, -1, 0, 1, 2}, +// the output shape should be {16, 200}. // if const_component_dim_ != 0, const_dim_ will be used to determine which // row of "in" we copy the last part of each row of "out" from (this part is // not subject to splicing, it's assumed constant for each frame of "input". @@ -54,24 +49,34 @@ class SpliceOp : public Operation { const Tensor *input = this->Input(0); MACE_CHECK(context_.size() > 0) << "The context param should not be empty in Splice Op."; + MACE_CHECK(input->dim_size() >= 2) + << "Splice's input's rank should be at least 2."; Tensor *output = this->Output(0); const std::vector &input_shape = input->shape(); - const index_t frames = - std::accumulate(input->shape().begin(), input->shape().end() - 1, 1, + const index_t batch = + std::accumulate(input->shape().begin(), input->shape().end() - 2, 1, std::multiplies()); - const index_t rank = input->dim_size(); + const index_t chunk = input_shape[rank - 2]; const index_t input_dim = input_shape[rank - 1]; + const index_t input_stride = chunk * input_dim; const index_t num_splice = static_cast(context_.size()); const index_t dim = input_dim - const_dim_; + const index_t left_context = context_[0]; + const index_t right_context = context_[num_splice -1]; + + const index_t out_chunk = chunk - (right_context - left_context); + MACE_CHECK(input_dim > const_dim_, "input dim should be greater than const dim."); const index_t output_dim = dim * num_splice + const_dim_; + 
const index_t output_stride = out_chunk * output_dim; std::vector output_shape = input->shape(); + output_shape[rank - 2] = out_chunk; output_shape[rank - 1] = output_dim; MACE_RETURN_IF_ERROR(output->Resize(output_shape)); @@ -80,28 +85,32 @@ class SpliceOp : public Operation { const T *input_data = input->data(); T *output_data = output->mutable_data(); -#pragma omp parallel for collapse(2) schedule(runtime) - for (index_t i = 0; i < frames; ++i) { +#pragma omp parallel for collapse(3) schedule(runtime) + for (int b = 0; b < batch; ++b) { + for (index_t i = 0; i < out_chunk; ++i) { for (index_t c = 0; c < num_splice; ++c) { - const index_t offset = - Clamp(context_[c] + i, 0, frames - 1); - T *output_base = output_data + i * output_dim + c * dim; - const T *input_base = input_data + offset * input_dim; + const index_t offset = i + context_[c] - left_context; + T *output_base = + output_data + b * output_stride + i * output_dim + c * dim; + const T *input_base = + input_data + b * input_stride + offset * input_dim; memcpy(output_base, input_base, dim * sizeof(T)); } } + } if (const_dim_ > 0) { const index_t output_offset = output_dim - const_dim_; const index_t input_offset = dim; -#pragma omp parallel for schedule(runtime) - for (index_t i = 0; i < frames; ++i) { - index_t offset = i + context_[0] >= 0 ? 
i + context_[0] : 0; - T *output_base = output_data + i * output_dim; - const T *input_base = input_data + offset * input_dim; +#pragma omp parallel for collapse(2) schedule(runtime) + for (int b = 0; b < batch; ++b) { + for (index_t i = 0; i < out_chunk; ++i) { + T *output_base = output_data + + b * output_stride + i * output_dim; + const T *input_base = input_data + b * input_stride + i * input_dim; memcpy(output_base + output_offset, input_base + input_offset, const_dim_ * sizeof(T)); + } } } return MaceStatus::MACE_SUCCESS; diff --git a/mace/ops/splice_benchmark.cc b/mace/ops/splice_benchmark.cc index 253808b8385e1526432cfdc3cd5befd98f70736b..2a90493dcf2535e341db49fe848127bf70bd1929 100644 --- a/mace/ops/splice_benchmark.cc +++ b/mace/ops/splice_benchmark.cc @@ -23,15 +23,15 @@ namespace { template void BMSpliceHelper(int iters, const std::vector &input_shape, - const index_t left_context, - const index_t right_context, + const int left_context, + const int right_context, const int const_component_dim) { mace::testing::StopTiming(); // Construct graph OpsTestNet net; - const int num_splice = left_context + right_context + 1; + const index_t num_splice = left_context + right_context + 1; std::vector contexts(num_splice); for (int i = 0; i < num_splice; ++i) { contexts[i] = left_context + i; @@ -44,7 +44,7 @@ void BMSpliceHelper(int iters, GenerateRandomRealTypeData(input_shape, &input_data); net.AddInputFromArray("Input", input_shape, input_data); - OpDefBuilder("Splice", "SpliceTest") + OpDefBuilder("Splice", "SpliceBM") .Input("Input") .Output("Output") .AddIntsArg("context", contexts) @@ -71,7 +71,6 @@ void BMSpliceHelper(int iters, MACE_BM_SPLICE_##N##_##H##_##W##_##L##_##R##_##C##_##TYPE##_##DEVICE( \ int iters) { \ const int64_t tot = static_cast(iters) * N * H * W; \ - mace::testing::MacsProcessed(tot); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ BMSpliceHelper(iters, {N, H, W}, L, R, C); \ } \ diff --git a/mace/ops/splice_test.cc 
b/mace/ops/splice_test.cc index 60e1652a394d7d1a7b88c0b1f537ec5fc688d613..b6bc3d32c179a4475f0c58c921138be901ee8c2b 100644 --- a/mace/ops/splice_test.cc +++ b/mace/ops/splice_test.cc @@ -53,14 +53,10 @@ TEST_F(SpliceOpTest, WithoutConstDim) { {1, 7, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}, {-2, -1, 0, 1, 2}, 0, - {1, 7, 10}, - {1, 2, 1, 2, 1, 2, 3, 4, 5, 6, - 1, 2, 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + {1, 3, 10}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 7, 8, 9, 10, 11, 12, 13, 14, 13, 14, - 9, 10, 11, 12, 13, 14, 13, 14, 13, 14}); + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); } TEST_F(SpliceOpTest, WithConstDim) { @@ -72,12 +68,8 @@ TEST_F(SpliceOpTest, WithConstDim) { 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}, {-2, -1, 0, 1, 2}, 7, - {1, 5, 22}, - {1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10, - 1, 2, 3, 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 4, 5, 6, 7, 8, 9, 10, - 1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10, - 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 5, 6, 7, 5, 6, 7, 8, 9, 10, 11, - 3, 4, 5, 4, 5, 6, 5, 6, 7, 5, 6, 7, 5, 6, 7, 6, 7, 8, 9, 10, 11, 12}); + {1, 1, 22}, + {1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 4, 5, 6, 7, 8, 9, 10}); } } // namespace test } // namespace ops diff --git a/mace/ops/time_offset_benchmark.cc b/mace/ops/time_offset_benchmark.cc deleted file mode 100644 index 82ea9967a9bd95542f012666593e81005cd64c48..0000000000000000000000000000000000000000 --- a/mace/ops/time_offset_benchmark.cc +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The MACE Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "mace/core/testing/test_benchmark.h" -#include "mace/ops/ops_test_util.h" - -namespace mace { -namespace ops { -namespace test { - -namespace { -template -void TimeOffsetBenchmark(int iters, - std::vector shape, - int offset) { - mace::testing::StopTiming(); - - OpsTestNet net; - - // Add input data - net.AddRandomInput("Input", shape); - - OpDefBuilder("TimeOffset", "TimeOffsetBM") - .Input("Input") - .Output("Output") - .AddIntArg("offset", offset) - .Finalize(net.NewOperatorDef()); - - // Warm-up - for (int i = 0; i < 5; ++i) { - net.RunOp(D); - } - net.Sync(); - - mace::testing::StartTiming(); - while (iters--) { - net.RunOp(D); - } - net.Sync(); -} -} // namespace - -#define MACE_BM_TIMEOFFSET2D_MACRO(H, W, TYPE, DEVICE) \ - static void MACE_BM_TIMEOFFSET2D_##H##_##W##_##TYPE##_##DEVICE(\ - int iters) { \ - const int64_t tot = static_cast(iters) * H * W; \ - mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - TimeOffsetBenchmark(iters, {H, W}, 1); \ - } \ - MACE_BENCHMARK(MACE_BM_TIMEOFFSET2D_##H##_##W##_##TYPE##_##DEVICE) \ - -#define MACE_BM_TIMEOFFSET2D(H, W) \ - MACE_BM_TIMEOFFSET2D_MACRO(H, W, float, CPU); - - -MACE_BM_TIMEOFFSET2D(20, 128); -MACE_BM_TIMEOFFSET2D(40, 512); -MACE_BM_TIMEOFFSET2D(1, 1024); -MACE_BM_TIMEOFFSET2D(20, 2048); -MACE_BM_TIMEOFFSET2D(20, 512); - -} // namespace test -} // namespace ops -} // namespace mace diff --git a/mace/ops/time_offset_test.cc b/mace/ops/time_offset_test.cc deleted file mode 100644 index 
b32b8c52acf3b8af715dac74f92d4a87efe1a102..0000000000000000000000000000000000000000 --- a/mace/ops/time_offset_test.cc +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2018 The MACE Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mace/ops/ops_test_util.h" - -namespace mace { -namespace ops { -namespace test { - -class TimeOffsetOpTest : public OpsTestBase {}; - -namespace { -template -void TestTimeOffset(const std::vector &input_shape, - const std::vector &input, - const int offset, - const std::vector &output) { - OpsTestNet net; - net.AddInputFromArray(MakeString("Input"), - input_shape, - input); - - OpDefBuilder("TimeOffset", "TimeOffsetTest") - .Input("Input") - .Output("Output") - .AddIntArg("offset", offset) - .Finalize(net.NewOperatorDef()); - - net.RunOp(); - - net.AddInputFromArray("ExpectedOutput", input_shape, output); - ExpectTensorNear(*net.GetOutput("ExpectedOutput"), - *net.GetOutput("Output")); -} -} // namespace - -TEST_F(TimeOffsetOpTest, Simple2Dim) { - TestTimeOffset( - {3, 5}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - -2, - {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}); - - TestTimeOffset( - {3, 5}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - -1, - {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - - TestTimeOffset( - {3, 5}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - 0, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - - TestTimeOffset( - {3, 5}, 
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - 1, - {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15}); - - TestTimeOffset( - {3, 5}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - 2, - {11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15}); -} - - -TEST_F(TimeOffsetOpTest, Simple3Dim) { - TestTimeOffset( - {2, 3, 5}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - -2, - {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, - 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}); - - TestTimeOffset( - {2, 3, 5}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - -1, - {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, - 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - - TestTimeOffset( - {2, 3, 5}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - 0, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - - TestTimeOffset( - {2, 3, 5}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - 1, - {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15, - 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15}); - - TestTimeOffset( - {2, 3, 5}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, - 2, - {11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15, - 11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15}); -} - -} // namespace test -} // namespace ops -} // namespace mace diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 7fc877d662a90bc4d6030daab3843b27cb801f80..bb1cb88a4f6ea9207e8defa058abce65140a11bc 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py 
@@ -102,7 +102,6 @@ class FrameworkType(Enum): MaceSupportedOps = [ 'Activation', 'AddN', - 'Affine', 'ArgMax', 'BatchNorm', 'BatchToSpaceND', @@ -126,10 +125,12 @@ MaceSupportedOps = [ 'InferConv2dShape', 'LocalResponseNorm', 'LSTMCell', - # 'LstmNonlinear', + 'LstmNonlinear', + 'DynamicLSTM', 'MatMul', 'OneHot', 'Pad', + 'PadContext', 'PNorm', 'Pooling', 'PriorBox', @@ -156,7 +157,6 @@ MaceSupportedOps = [ 'SqrDiffMean', 'SumGroup', 'TargetRMSNorm', - 'TimeOffset', 'Transpose', 'WinogradInverseTransform', 'WinogradTransform', diff --git a/mace/python/tools/converter_tool/onnx_converter.py b/mace/python/tools/converter_tool/onnx_converter.py index 68f781a23dfc4fe5d09163b59422be15fec31f87..805cbd272731e1548a11dc23a284b305f2a2ee14 100644 --- a/mace/python/tools/converter_tool/onnx_converter.py +++ b/mace/python/tools/converter_tool/onnx_converter.py @@ -72,6 +72,7 @@ OnnxSupportedOps = [ 'DimRange', 'Div', 'Dropout', + 'DynamicLstmCell', 'Elu', 'Equal', # 'Exp', @@ -90,16 +91,16 @@ OnnxSupportedOps = [ # 'Hardmax', 'Identity', # 'If', - 'IfDefined', + # 'IfDefined', 'ImageScaler', # 'InstanceNormalization', # 'LRN', 'LSTM', - # 'LstmNonlinear', + 'LstmNonlinear', 'LeakyRelu', # 'Less', # 'Log', - # 'LogSoftmax', + 'LogSoftmax', # 'Loop', # 'LpNormalization', # 'LpPool', @@ -120,6 +121,7 @@ OnnxSupportedOps = [ # 'Or', 'PRelu', # 'Pad', + 'PadContext', 'Padding', 'PNorm', 'Pow', @@ -331,12 +333,11 @@ class OnnxConverter(base_converter.ConverterInterface): OnnxOpType.GlobalAveragePool.name: self.convert_reduce, OnnxOpType.GlobalMaxPool.name: self.convert_reduce, OnnxOpType.Identity.name: self.convert_identity, - OnnxOpType.IfDefined.name: self.convert_identity, OnnxOpType.ImageScaler.name: self.convert_imagescaler, OnnxOpType.LeakyRelu.name: self.convert_activation, - # OnnxOpType.LogSoftmax.name: self.convert_softmax, - OnnxOpType.LSTM.name: self.convert_lstm, - # OnnxOpType.LstmNonlinear.name: self.convert_lstm_nonlinear, + OnnxOpType.LogSoftmax.name: 
self.convert_softmax, + OnnxOpType.LstmNonlinear.name: self.convert_lstm_nonlinear, + OnnxOpType.DynamicLstmCell.name: self.convert_dynamic_lstm, OnnxOpType.Max.name: self.convert_eltwise, OnnxOpType.MaxPool.name: self.convert_pooling, OnnxOpType.MatMul.name: self.convert_matmul, @@ -344,7 +345,8 @@ class OnnxConverter(base_converter.ConverterInterface): OnnxOpType.Mul.name: self.convert_eltwise, OnnxOpType.Neg.name: self.convert_eltwise, OnnxOpType.Normalize: self.convert_normalize, - OnnxOpType.Offset.name: self.convert_timeoffset, + OnnxOpType.Offset.name: self.convert_identity, + OnnxOpType.PadContext.name: self.convert_pad_context, OnnxOpType.Padding.name: self.convert_identity, OnnxOpType.PNorm.name: self.convert_pnorm, OnnxOpType.Pow.name: self.convert_eltwise, @@ -642,7 +644,7 @@ class OnnxConverter(base_converter.ConverterInterface): mace_check(axis_value == 1 or axis_value == -3, "only support concat at channel dimension") elif node.op_type == OnnxOpType.Append.name: - axis_value = 2 + axis_value = 1 axis_arg = op.arg.add() axis_arg.name = MaceKeyword.mace_axis_str axis_arg.i = 4 + axis_value if axis_value < 0 else axis_value @@ -758,14 +760,69 @@ class OnnxConverter(base_converter.ConverterInterface): offset = node.attrs['offset'] starts_arg = op.arg.add() starts_arg.name = 'starts' - starts_arg.ints.append(offset) + starts_arg.ints.extend([offset]) output_dim = node.attrs['output_dim'] ends_arg = op.arg.add() - ends_arg.name = 'output_dim' - ends_arg.ints.append(output_dim) + ends_arg.name = 'ends' + ends_arg.ints.extend([output_dim + offset]) axes_arg = op.arg.add() axes_arg.name = 'axes' - axes_arg.ints.append(-1) + axes_arg.ints.extend([-1]) + + def convert_dynamic_lstm(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.DynamicLSTM.name + + if 'delay_a' in node.attrs: + prev_out_delay = node.attrs['delay_a'] + mace_check(prev_out_delay < 0, + "dynamic's prev_out_delay should < 0.") + prev_out_delay_arg = op.arg.add() + 
prev_out_delay_arg.name = 'prev_out_delay' + prev_out_delay_arg.i = prev_out_delay + if 'delay_b' in node.attrs: + prev_cell_delay = node.attrs['delay_b'] + mace_check(prev_cell_delay < 0, + "dynamic's prev_cell_delay should < 0.") + prev_cell_delay_arg = op.arg.add() + prev_cell_delay_arg.name = 'prev_cell_delay' + prev_cell_delay_arg.i = prev_cell_delay + if 'prev_out_offset' in node.attrs: + prev_out_offset = node.attrs['prev_out_offset'] + mace_check(prev_out_offset >= 0, + "dynamic's prev_out_offset should >= 0.") + prev_out_offset_arg = op.arg.add() + prev_out_offset_arg.name = 'prev_out_offset' + prev_out_offset_arg.i = prev_out_offset + if 'prev_a_dim' in node.attrs: + prev_out_dim = node.attrs['prev_a_dim'] + mace_check(prev_out_dim > 0, + "dynamic's prev_out_dim should > 0.") + prev_out_dim_arg = op.arg.add() + prev_out_dim_arg.name = 'prev_out_dim' + prev_out_dim_arg.i = prev_out_dim + if 'prev_b_dim' in node.attrs: + prev_cell_dim = node.attrs['prev_b_dim'] + mace_check(prev_cell_dim > 0, + "dynamic's prev_cell_dim should > 0.") + prev_cell_dim_arg = op.arg.add() + prev_cell_dim_arg.name = 'prev_cell_dim' + prev_cell_dim_arg.i = prev_cell_dim + if 'bias_a' in node.attrs: + bias_a = node.attrs['bias_a'] + bias_a_arg = op.arg.add() + bias_a_arg.name = 'bias_a' + bias_a_arg.i = bias_a + if 'bias_b' in node.attrs: + bias_b = node.attrs['bias_b'] + bias_b_arg = op.arg.add() + bias_b_arg.name = 'bias_b' + bias_b_arg.i = bias_b + if 'scale' in node.attrs: + scale = node.attrs['scale'] + scale_arg = op.arg.add() + scale_arg.name = 'scale' + scale_arg.f = scale def convert_eltwise(self, node): op = self.convert_general_op(node) @@ -925,6 +982,18 @@ class OnnxConverter(base_converter.ConverterInterface): op = self.convert_general_op(node) op.type = MaceOp.BatchNorm.name + def convert_pad_context(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.PadContext.name + if 'left_context' in node.attrs: + left_context_arg = op.arg.add() + 
left_context_arg.name = 'left_context' + left_context_arg.i = node.attrs['left_context'] + if 'right_context' in node.attrs: + right_context_arg = op.arg.add() + right_context_arg.name = 'right_context' + right_context_arg.i = node.attrs['right_context'] + def convert_pnorm(self, node): op = self.convert_general_op(node) op.type = MaceOp.PNorm.name @@ -1010,10 +1079,10 @@ class OnnxConverter(base_converter.ConverterInterface): op = self.convert_general_op(node) op.type = MaceOp.Softmax.name # TODO: add logsoftmax in softmax op - # if node.op_type == OnnxOpType.LogSoftmax.name: - # use_log_arg = op.arg.add() - # use_log_arg.name = 'use_log' - # use_log_arg.i = 1 + if node.op_type == OnnxOpType.LogSoftmax.name: + use_log_arg = op.arg.add() + use_log_arg.name = 'use_log' + use_log_arg.i = 1 def convert_splice(self, node): op = self.convert_general_op(node) @@ -1104,6 +1173,11 @@ class OnnxConverter(base_converter.ConverterInterface): else: op.type = MaceOp.TimeOffset.name + chunk_size = node.attrs['chunk_size'] + chunk_size_arg = op.arg.add() + chunk_size_arg.name = 'chunk_size' + chunk_size_arg.i = chunk_size + offset_arg = op.arg.add() offset_arg.name = 'offset' offset_arg.i = offset diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 1083e23545767725e2f4e0d9c394d790fd5d0dd3..a097b10aba2691491ccf80a3d8952f4ff1500151 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -1143,6 +1143,7 @@ class Transformer(base_converter.ConverterInterface): filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape arg.i = 1 + six.print_('transpose matmul weight') def transpose_filters(self): net = self._model @@ -1192,7 +1193,6 @@ class Transformer(base_converter.ConverterInterface): mace_check(filter_format == DataFormat.HWIO, "HEXAGON only support HWIO/HWIM filter format.") else: - print("Transpose filters to OIHW/MIHW") # transpose 
filter to OIHW/MIHW for tensorflow (HWIO/HWIM) if filter_format == DataFormat.HWIO: for op in net.op: @@ -1201,6 +1201,7 @@ class Transformer(base_converter.ConverterInterface): or op.type == MaceOp.DepthwiseConv2d.name) \ and op.input[1] in self._consts \ and op.input[1] not in transposed_filter: + print("Transpose Conv2D/Deconv2D filters to OIHW/MIHW") filter = self._consts[op.input[1]] filter_data = np.array(filter.float_data).reshape( filter.dims) @@ -1208,9 +1209,13 @@ class Transformer(base_converter.ConverterInterface): filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape transposed_filter.add(op.input[1]) - if (op.type == MaceOp.MatMul.name - and (ConverterUtil.get_arg(op, MaceKeyword.mace_winograd_filter_transformed) is not None) # noqa + if (op.type == MaceOp.MatMul.name and + (ConverterUtil.get_arg( + op, + MaceKeyword.mace_winograd_filter_transformed) + is not None) # noqa and op.input[1] not in transposed_filter): + print("Transpose Winograd filters to OIHW/MIHW") filter = self._consts[op.input[0]] filter_data = np.array(filter.float_data).reshape( filter.dims) @@ -1222,6 +1227,8 @@ class Transformer(base_converter.ConverterInterface): and op.input[1] not in transposed_filter: weight = self._consts[op.input[1]] if len(weight.dims) == 4: + print("Transpose FullyConnected filters to" + " OIHW/MIHW") weight_data = np.array(weight.float_data).reshape( weight.dims) weight_data = weight_data.transpose(3, 2, 0, 1) diff --git a/mace/utils/math.h b/mace/utils/math.h index 0293806c66667d55439b6802e1a8ec3943c1635e..2a806ab433757e3e3c48a2d2ea57dabaad50da32 100644 --- a/mace/utils/math.h +++ b/mace/utils/math.h @@ -60,25 +60,22 @@ inline Integer Clamp(Integer in, Integer low, Integer high) { return std::max(low, std::min(in, high)); } -template -inline T ScalarSigmoid(T in) { - if (in > static_cast(0)) { - return static_cast(1) / (static_cast(1) + std::exp(-in)); +inline float ScalarSigmoid(float in) { + if (in > 0) { + return 1 / (1 + 
std::exp(-in)); } else { - T x = std::exp(in); - return x / (x + static_cast(1)); + float x = std::exp(in); + return x / (x + 1.f); } } -template -inline T ScalarTanh(T in) { - if (in > static_cast(0)) { - T inv_expa = std::exp(-in); - return -static_cast(1) + - static_cast(2) / (static_cast(1) + inv_expa * inv_expa); +inline float ScalarTanh(float in) { + if (in > 0) { + float x = std::exp(-in); + return -1.f + 2.f / (1.f + x * x); } else { - T x = std::exp(in); - return x / (x + static_cast(1)); + float x = std::exp(in); + return 1.f - 2.f / (1.f + x * x); } }