    Incorporate cudnn_lstm into LSTM api (#27217) · fa9d3fa5
    Guo Sheng committed
    * Incorporate cudnn_lstm into LSTM api.
    test=develop
    
    * Make coalesce_tensor support alignment optionally.
    test=develop
    
    * Reorganize RNN apis. test=develop
    
    * Fix cudnn rnn layout conversion.
    test=develop
    
    * Add sequence_length support for the RNN cudnn implementation.
    Add optional init_h and init_c gradients for cudnn_lstm_op.
    test=develop
    
    * Use create_parameter for rnn cudnn impl.
    test=develop
    
    * Move `self._flat_weight = self.create_parameter()` in RNNBase to main_program.
    test=develop
    
    * Update RNN api unittest to use set_device.
    test=develop
    
    * Fix set_place for unit tests of RNN apis.
    test=develop
    
    * Fix use_align in coalesce_tensor_op.
    test=develop
    
    * Adjust RNN API arguments according to review comments.
    test=develop
    
    * Polish documents for SimpleRNN apis.
    test=develop
    
    * Refine random seed in cudnn_lstm_op.
    Expose rnn params from sublayers to RNN.
    test=develop
    
    * Fix RNN saving for jit.save.
    Refine cudnn_lstm dropout behavior.
    test=develop
    
    * Fix doc of GRU. test=develop
    
    * Use ShareDataWith to avoid copying for cudnn_lstm_op test.
    test=develop
    
    * Remove updates on cudnn_lstm temporarily.
    test=develop
    
    * Use ShareDataWith to avoid copying for cudnn_lstm_op test.
    test=develop
    
    * Refine random seed in cudnn_lstm_op.
    test=develop
    
    * Fix test_lstm by adjusting the ConcreteProgram buffer getter.
    test=develop
    
    * Use create_parameter instead of create_var for rnn._flat_weight for static graph usage.
    test=develop
    
    * Remove W input for cudnn_lstm to pass unused_var_check.
    test=develop
    
    * Add test_predict to improve RNN unit test coverage.
    test=develop
    
    * Fix code style of rnn.
    test=develop
    
    * Fix F.rnn usage in rnn.py.
    test=develop
cudnn_lstm_op.cu.cc 18.9 KB