diff --git a/docs/development/dynamic_lstm.rst b/docs/development/dynamic_lstm.rst index b01bdbc79d3bba56a8d0a821726e5abeccffd8ca..30cb8c1347c9f66af013aac71ccc908982909a27 100644 --- a/docs/development/dynamic_lstm.rst +++ b/docs/development/dynamic_lstm.rst @@ -20,4 +20,4 @@ After fusing: :align: center For more details about LSTMNonlinear in Kaldi, -please refer to [LstmNonlinearComponent](http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164) \ No newline at end of file +please refer to [LstmNonlinearityComponent](http://kaldi-asr.org/doc/nnet-combined-component_8h_source.html#l00255) diff --git a/mace/ops/dynamic_lstm.cc b/mace/ops/dynamic_lstm.cc index 36d24ada47295095f7ce1841113a908749c21bf7..9ef15cccdd0005ff1c2621820137c31045307129 100644 --- a/mace/ops/dynamic_lstm.cc +++ b/mace/ops/dynamic_lstm.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This Op is for Fused-LstmNonlinearComponent +// This Op is for Fused-LstmNonlinearityComponent // with prev cell states as inputs in Kaldi. -// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164 +// http://kaldi-asr.org/doc/nnet-combined-component_8h_source.html#l00255 // More details are in docs/development/dynamic_lstm.md #include diff --git a/mace/ops/lstm_nonlinear.cc b/mace/ops/lstm_nonlinear.cc index bf28c79e43f79307e074922f0c94ebf311901f25..fbf92c16e4361623d41dfbb50e704a4d8a81021e 100644 --- a/mace/ops/lstm_nonlinear.cc +++ b/mace/ops/lstm_nonlinear.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This Op is for LstmNonlinearComponent in Kaldi. -// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164 +// This Op is for LstmNonlinearityComponent in Kaldi. +// http://kaldi-asr.org/doc/nnet-combined-component_8h_source.html#l00255 #include #include diff --git a/mace/ops/target_rms_norm.cc b/mace/ops/target_rms_norm.cc index 6caf1ce520891d3941f965e0c69d0f3e565eb843..eab76620bc07518f889ef4cc2c73e8e2c24076f0 100644 --- a/mace/ops/target_rms_norm.cc +++ b/mace/ops/target_rms_norm.cc @@ -104,6 +104,7 @@ class TargetRMSNormOp : public Operation { std::accumulate(input_shape.begin(), input_shape.end() - 1, 1, std::multiplies()); if (block_dim_ == 0) block_dim_ = static_cast(input_dim); + MACE_CHECK(input_dim % block_dim_ == 0, "block_dim must divide input_dim!"); const index_t output_dim = add_log_stddev_ ? input_dim + (input_dim / block_dim_) : input_dim; std::vector output_shape = input->shape();