提交 4e8a3b65 编写于 作者: Y yejianwu

fix kaldi lstm typo, update docs, add block_dim limit check for TargetRmsOp

上级 c498a94c
......@@ -20,4 +20,4 @@ After fusing:
:align: center
For more details about LSTMNonlinear in Kaldi,
please refer to [LstmNonlinearComponent](http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164)
\ No newline at end of file
please refer to [LstmNonlinearityComponent](http://kaldi-asr.org/doc/nnet-combined-component_8h_source.html#l00255)
......@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for Fused-LstmNonlinearComponent
// This Op is for Fused-LstmNonlinearityComponent
// with prev cell states as inputs in Kaldi.
// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164
// http://kaldi-asr.org/doc/nnet-combined-component_8h_source.html#l00255
// More details are in docs/development/dynamic_lstm.md
#include <functional>
......
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for LstmNonlinearComponent in Kaldi.
// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164
// This Op is for LstmNonlinearityComponent in Kaldi.
// http://kaldi-asr.org/doc/nnet-combined-component_8h_source.html#l00255
#include <functional>
#include <memory>
......
......@@ -104,6 +104,7 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
std::accumulate(input_shape.begin(), input_shape.end() - 1, 1,
std::multiplies<index_t>());
if (block_dim_ == 0) block_dim_ = static_cast<int>(input_dim);
MACE_CHECK(input_dim % block_dim_ == 0, "block_dim must divide input_dim!");
const index_t output_dim = add_log_stddev_ ?
input_dim + (input_dim / block_dim_) : input_dim;
std::vector<index_t> output_shape = input->shape();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册