    [cherry-pick] Incorporate cudnn_lstm into LSTM api (#27217) (#28023) · 3f565903
    Guo Sheng committed
    * Incorporate cudnn_lstm into LSTM api (#27217)
    
    * Incorporate cudnn_lstm into LSTM api.
    test=develop
    
    * Make coalesce_tensor support alignment optionally.
    test=develop
    
    * Reorganize RNN apis. test=develop
    
    * Fix cudnn rnn layout conversion.
    test=develop
    
    * Add sequence_length support for RNN cudnn implementation.
    Add optional init_h and init_c gradient for cudnn_lstm_op.
    test=develop
    
    * Use create_parameter for rnn cudnn impl.
    test=develop
    
    * Move `self._flat_weight = self.create_parameter()` in RNNBase to main_program.
    test=develop
    
    * Update RNN api unittest to use set_device.
    test=develop
    
    * Fix set_place for unit tests of RNN apis.
    test=develop
    
    * Fix use_align in coalesce_tensor_op.
    test=develop
    
    * Adjust RNN apis arguments according to comments.
    test=develop
    
    * Polish documents for SimpleRNN apis.
    test=develop
    
    * Refine random seed in cudnn_lstm_op.
    Expose rnn params from sublayers to RNN.
    test=develop
    
    * Fix RNN saving for jit.save.
    Refine cudnn_lstm dropout behavior.
    test=develop
    
    * Fix doc of GRU. test=develop
    
    * Use ShareDataWith to avoid copying for cudnn_lstm_op test.
    test=develop
    
    * Remove updates on cudnn_lstm temporarily.
    test=develop
    
    * Use ShareDataWith to avoid copying for cudnn_lstm_op test.
    test=develop
    
    * Refine random seed in cudnn_lstm_op.
    test=develop
    
    * Fix test_lstm by adjusting ConcreteProgram buffer getter.
    test=develop
    
    * Use create_parameter instead of create_var for rnn._flat_weight for static graph usage.
    test=develop
    
    * Remove W input for cudnn_lstm to pass unused_var_check.
    test=develop
    
    * Add test_predict for RNN unit tests coverage.
    test=develop
    
    * Fix code style of rnn.
    test=develop
    
    * Fix F.rnn usage in rnn.py.
    test=develop
    
    * Fix test_lstm unittest failure and add more unittests (#28029)
    
    * fix test_lstm unittest failed
    
    * add more unittest
    
    * modify cmakelist
    
    * fix judgement
    Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
test_rnn_nets.py 12.1 KB