Commit 7163dd04 authored by Q qiaolongfei

revert code

Parent 8db3afad
...@@ -28,6 +28,29 @@ using Variable = framework::Variable;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

void RecurrentAlgorithm::InferShape(const Scope& scope) const {
  auto* input0 = scope.FindVar(arg_->inlinks[0]);
  PADDLE_ENFORCE_NOT_NULL(input0);
  seq_len_ = input0->GetMutable<LoDTensor>()->dims()[0];
  PADDLE_ENFORCE_GT(seq_len_, 0);
  CreateScopes(scope);
  auto& step_scopes = GetStepScopes(scope);
  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                     true /*infer_shape_mode*/);
  InitMemories(step_scopes[0], true /*infer_shape_mode*/);
  for (size_t i = 0; i < seq_len_; i++) {
    if (i > 0) {
      rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
                        true /*infer_shape_mode*/);
    }
    (*stepnet_)->InferShape(*step_scopes[i]);
  }
  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                     true /*infer_shape_mode*/);
}
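
The restored forward-pass InferShape above must run before Run: it creates the step scopes, segments the inlinks across seq_len_ time steps, links each step's memories to the previous step, and pushes shapes through the step net. A minimal caller sketch, assuming the op construction and scope setup that this diff does not show:

// Hedged sketch of the call order this commit restores; the scope and
// device-context setup are assumptions, only InferShape/Run come from
// the diff above.
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
// ... build the RecurrentOp op, set its stepnet, and feed the inlink
// variables in scope with LoDTensors whose first dimension is seq_len ...
op->InferShape(scope);    // creates step scopes and infers per-step dims
op->Run(scope, dev_ctx);  // then executes the step net per time step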

void RecurrentAlgorithm::Run(const Scope& scope,
                             const platform::DeviceContext& dev_ctx) const {
  auto step_scopes = GetStepScopes(scope);
...@@ -179,6 +202,24 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
  }
}

void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
  seq_len_ =
      scope.FindVar(arg_->inlinks[0])->GetMutable<LoDTensor>()->dims()[0];
  auto step_scopes = GetStepScopes(scope);
  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                     true /*infer_shape_mode*/);
  for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
    if (static_cast<size_t>(step_id) != seq_len_ - 1) {
      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                        true /*infer_shape_mode*/);
    }
    (*stepnet_)->InferShape(*step_scopes[step_id]);
  }
  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                     true /*infer_shape_mode*/);
  LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
}
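
The gradient pass mirrors the forward pass: the forward loop walks steps 0..seq_len_-1 and links memories with offset -1 (reading from step i-1), while the gradient loop walks seq_len_-1..0 and links with offset +1 (reading from step i+1), before finally wiring the boot-memory gradients. A self-contained toy sketch of just that traversal, where link() is a hypothetical stand-in for rnn::LinkMemories, not Paddle code:

#include <cstdio>

// Stand-in for rnn::LinkMemories: records which neighbouring step a
// given step reads its memory from.
static void link(int step, int offset) {
  std::printf("step %d reads memory of step %d\n", step, step + offset);
}

int main() {
  const int seq_len = 4;
  for (int i = 0; i < seq_len; ++i)       // forward pass, offset -1
    if (i > 0) link(i, -1);
  for (int i = seq_len - 1; i >= 0; --i)  // gradient pass, offset +1
    if (i != seq_len - 1) link(i, +1);
  return 0;
}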

RecurrentGradientOp::RecurrentGradientOp(
    const std::string& type, const framework::VariableNameMap& inputs,
    const framework::VariableNameMap& outputs,
...
...@@ -41,6 +41,11 @@ class RecurrentAlgorithm {
    stepnet_ = stepnet;
  }

  /**
   * InferShape must be called before Run.
   */
  void InferShape(const framework::Scope& scope) const;

 protected:
  /*
   * The step scopes will be stored in the father scope as a variable.
...@@ -89,6 +94,11 @@ class RecurrentGradientAlgorithm {
  void LinkBootMemoryGradients(framework::Scope* step_scopes,
                               bool infer_shape_mode) const;

  /**
   * InferShape must be called before Run.
   */
  void InferShape(const framework::Scope& scope) const;

 protected:
  inline const std::vector<framework::Scope*>& GetStepScopes(
      const framework::Scope& scope) const {
...@@ -123,8 +133,13 @@ class RecurrentOp : public framework::OperatorBase {
  void set_stepnet(std::unique_ptr<OperatorBase> net) {
    stepnet_ = std::move(net);
  }

  const OperatorBase& stepnet() const { return *stepnet_; }

  void InferShape(const framework::Scope& scope) const {
    alg_.InferShape(scope);
  }

  static const rnn::ArgumentName kArgName;

 private:
...@@ -147,6 +162,10 @@ class RecurrentGradientOp : public framework::OperatorBase {
    PADDLE_THROW("Not Implemented");
  }

  void InferShape(const framework::Scope& scope) const {
    alg_.InferShape(scope);
  }

  void Run(const framework::Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {
    alg_.Run(scope, dev_ctx);
...
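
Both operator wrappers in the header simply delegate InferShape to their algorithm members, so the forward and gradient ops share one calling convention. An assumed end-to-end driver, with construction details again elided because they are not part of this commit:

// Assumed usage; only the InferShape/Run signatures come from the
// header above, and fwd_op/grad_op construction is hypothetical.
fwd_op->InferShape(scope);   // forward shapes first
fwd_op->Run(scope, dev_ctx);
grad_op->InferShape(scope);  // gradient shapes, traversed back to front
grad_op->Run(scope, dev_ctx);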
...@@ -197,7 +197,4 @@ class RecurrentGradientOpTest(unittest.TestCase):
if __name__ == '__main__':
    exit(
        0
    )  # FIXME(yuyang18): InferShape has been removed, this unittest may error
    unittest.main()