提交 b3526d30 编写于 作者: 刘托

Merge branch 'fix_typo' into 'master'

fix kaldi lstm typo, update docs, add block_dim limit check for TargetRmsOp

See merge request !1112
...@@ -20,4 +20,4 @@ After fusing: ...@@ -20,4 +20,4 @@ After fusing:
:align: center :align: center
For more details about LSTMNonlinear in Kaldi, For more details about LSTMNonlinear in Kaldi,
please refer to [LstmNonlinearComponent](http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164) please refer to [LstmNonlinearityComponent](http://kaldi-asr.org/doc/nnet-combined-component_8h_source.html#l00255)
\ No newline at end of file
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// This Op is for Fused-LstmNonlinearComponent // This Op is for Fused-LstmNonlinearityComponent
// with prev cell states as inputs in Kaldi. // with prev cell states as inputs in Kaldi.
// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164 // http://kaldi-asr.org/doc/nnet-combined-component_8h_source.html#l00255
// More details are in docs/development/dynamic_lstm.md // More details are in docs/development/dynamic_lstm.md
#include <functional> #include <functional>
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// This Op is for LstmNonlinearComponent in Kaldi. // This Op is for LstmNonlinearityComponent in Kaldi.
// http://kaldi-asr.org/doc/nnet-simple-component_8h_source.html#l02164 // http://kaldi-asr.org/doc/nnet-combined-component_8h_source.html#l00255
#include <functional> #include <functional>
#include <memory> #include <memory>
......
...@@ -104,6 +104,7 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation { ...@@ -104,6 +104,7 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
std::accumulate(input_shape.begin(), input_shape.end() - 1, 1, std::accumulate(input_shape.begin(), input_shape.end() - 1, 1,
std::multiplies<index_t>()); std::multiplies<index_t>());
if (block_dim_ == 0) block_dim_ = static_cast<int>(input_dim); if (block_dim_ == 0) block_dim_ = static_cast<int>(input_dim);
MACE_CHECK(input_dim % block_dim_ == 0, "block_dim must divide input_dim!");
const index_t output_dim = add_log_stddev_ ? const index_t output_dim = add_log_stddev_ ?
input_dim + (input_dim / block_dim_) : input_dim; input_dim + (input_dim / block_dim_) : input_dim;
std::vector<index_t> output_shape = input->shape(); std::vector<index_t> output_shape = input->shape();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册