diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 80de229c333f645fb3098b97fa076c6b77bb7ca9..b9fba3e1354eda23654f34607f77c1106d07b55e 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -28,6 +28,29 @@ using Variable = framework::Variable; using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; +void RecurrentAlgorithm::InferShape(const Scope& scope) const { + auto* input0 = scope.FindVar(arg_->inlinks[0]); + PADDLE_ENFORCE_NOT_NULL(input0); + seq_len_ = input0->GetMutable<LoDTensor>()->dims()[0]; + PADDLE_ENFORCE_GT(seq_len_, 0); + + CreateScopes(scope); + auto& step_scopes = GetStepScopes(scope); + rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, + true /*infer_shape_mode*/); + InitMemories(step_scopes[0], true /*infer_shape_mode*/); + + for (size_t i = 0; i < seq_len_; i++) { + if (i > 0) { + rnn::LinkMemories(step_scopes, arg_->memories, i, -1, + true /*infer_shape_mode*/); + } + (*stepnet_)->InferShape(*step_scopes[i]); + } + rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, + true /*infer_shape_mode*/); +} + void RecurrentAlgorithm::Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const { auto step_scopes = GetStepScopes(scope); @@ -179,6 +202,24 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients( } } +void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { + seq_len_ = + scope.FindVar(arg_->inlinks[0])->GetMutable<LoDTensor>()->dims()[0]; + auto step_scopes = GetStepScopes(scope); + rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, + true /*infer_shape_mode*/); + for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { + if (static_cast<size_t>(step_id) != seq_len_ - 1) { + rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, + true /*infer_shape_mode*/); + } + (*stepnet_)->InferShape(*step_scopes[step_id]); + } + rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, + true /*infer_shape_mode*/); + 
LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/); +} + RecurrentGradientOp::RecurrentGradientOp( const std::string& type, const framework::VariableNameMap& inputs, const framework::VariableNameMap& outputs, diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index c6b9a5533eece9057449b5c875ddcb3cefe716f0..18f8c53e18a325624987ac0c8938ee5fc036e55a 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -41,6 +41,11 @@ class RecurrentAlgorithm { stepnet_ = stepnet; } + /** + * InferShape must be called before Run. + */ + void InferShape(const framework::Scope& scope) const; + protected: /* * The step scopes will be stored in the father scope as a variable. @@ -89,6 +94,11 @@ class RecurrentGradientAlgorithm { void LinkBootMemoryGradients(framework::Scope* step_scopes, bool infer_shape_mode) const; + /** + * InferShape must be called before Run. + */ + void InferShape(const framework::Scope& scope) const; + protected: inline const std::vector<framework::Scope*>& GetStepScopes( const framework::Scope& scope) const { @@ -123,8 +133,13 @@ class RecurrentOp : public framework::OperatorBase { void set_stepnet(std::unique_ptr<OperatorBase> net) { stepnet_ = std::move(net); } + const OperatorBase& stepnet() const { return *stepnet_; } + void InferShape(const framework::Scope& scope) const { + alg_.InferShape(scope); + } + static const rnn::ArgumentName kArgName; private: @@ -147,6 +162,10 @@ class RecurrentGradientOp : public framework::OperatorBase { PADDLE_THROW("Not Implemented"); } + void InferShape(const framework::Scope& scope) const { + alg_.InferShape(scope); + } + void Run(const framework::Scope& scope, const platform::DeviceContext& dev_ctx) const override { alg_.Run(scope, dev_ctx); diff --git a/python/paddle/v2/framework/tests/test_recurrent_op.py b/python/paddle/v2/framework/tests/test_recurrent_op.py index 92161ae5dd93d34d898a2027435cc5e55611bcd0..6b9e7a88cea54e825413853cd0266acb86adca2f 100644 --- 
a/python/paddle/v2/framework/tests/test_recurrent_op.py +++ b/python/paddle/v2/framework/tests/test_recurrent_op.py @@ -197,7 +197,4 @@ class RecurrentGradientOpTest(unittest.TestCase): if __name__ == '__main__': - exit( - 0 - ) # FIXME(yuyang18): InferShape has been removed, this unittest may error unittest.main()