From 4e8a3b6529786db81693b08d43950343b5289dae Mon Sep 17 00:00:00 2001 From: yejianwu Date: Wed, 22 May 2019 10:47:05 +0800 Subject: [PATCH] fix kaldi lstm typo, update docs, add block_dim limit check for TargetRmsOp --- docs/development/dynamic_lstm.rst | 2 +- mace/ops/dynamic_lstm.cc | 4 ++-- mace/ops/lstm_nonlinear.cc | 4 ++-- mace/ops/target_rms_norm.cc | 1 + 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/development/dynamic_lstm.rst b/docs/development/dynamic_lstm.rst index b01bdbc7..30cb8c13 100644 --- a/docs/development/dynamic_lstm.rst +++ b/docs/development/dynamic_lstm.rst @@ -20,4 +20,4 @@ After fusing: :align: center For more details about LSTMNonlinear in Kaldi, -please refer to [LstmNonlinearComponent](http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164) \ No newline at end of file +please refer to [LstmNonlinearityComponent](http://kaldi-asr.org/doc/nnet-combined-component_8h_source.html#l00255) diff --git a/mace/ops/dynamic_lstm.cc b/mace/ops/dynamic_lstm.cc index 36d24ada..9ef15ccc 100644 --- a/mace/ops/dynamic_lstm.cc +++ b/mace/ops/dynamic_lstm.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This Op is for Fused-LstmNonlinearComponent +// This Op is for Fused-LstmNonlinearityComponent // with prev cell states as inputs in Kaldi. -// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164 +// http://kaldi-asr.org/doc/nnet-combined-component_8h_source.html#l00255 // More details are in docs/development/dynamic_lstm.md #include diff --git a/mace/ops/lstm_nonlinear.cc b/mace/ops/lstm_nonlinear.cc index bf28c79e..fbf92c16 100644 --- a/mace/ops/lstm_nonlinear.cc +++ b/mace/ops/lstm_nonlinear.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This Op is for LstmNonlinearComponent in Kaldi. -// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164 +// This Op is for LstmNonlinearityComponent in Kaldi. +// http://kaldi-asr.org/doc/nnet-combined-component_8h_source.html#l00255 #include #include diff --git a/mace/ops/target_rms_norm.cc b/mace/ops/target_rms_norm.cc index 6caf1ce5..eab76620 100644 --- a/mace/ops/target_rms_norm.cc +++ b/mace/ops/target_rms_norm.cc @@ -104,6 +104,7 @@ class TargetRMSNormOp : public Operation { std::accumulate(input_shape.begin(), input_shape.end() - 1, 1, std::multiplies()); if (block_dim_ == 0) block_dim_ = static_cast(input_dim); + MACE_CHECK(input_dim % block_dim_ == 0, "block_dim must divide input_dim!"); const index_t output_dim = add_log_stddev_ ? input_dim + (input_dim / block_dim_) : input_dim; std::vector output_shape = input->shape(); -- GitLab