diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc
index aeb95569b728f53b288a0c9a28220be8b5f7aaa4..2fdaaaf05c5a1428a25946452c97b3f6e2849f2f 100644
--- a/paddle/operators/recurrent_op.cc
+++ b/paddle/operators/recurrent_op.cc
@@ -38,10 +38,10 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    "input link [%s] is not in scope.",
                    inlinks[i].external);
     Tensor* input = input_var->GetMutable<Tensor>();
-    DDim dims = input->dims();
+    framework::DDim dims = input->dims();
     PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                    "all the inlinks must have same length");
-    DDim step_dims = slice_ddim(dims, 1, dims.size());
+    framework::DDim step_dims = slice_ddim(dims, 1, dims.size());
     for (size_t j = 0; j < seq_len; j++) {
       Tensor* step_input =
           step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
@@ -64,13 +64,13 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    outlinks[i].external);
     Tensor* output = output_var->GetMutable<Tensor>();
     if (infer_shape_mode) {
-      DDim step_dims = step_scopes[0]
-                           ->FindVar(outlinks[i].internal)
-                           ->GetMutable<Tensor>()
-                           ->dims();
+      framework::DDim step_dims = step_scopes[0]
+                                      ->FindVar(outlinks[i].internal)
+                                      ->GetMutable<Tensor>()
+                                      ->dims();
       std::vector<int> dims_vec = vectorize(step_dims);
       dims_vec.insert(dims_vec.begin(), seq_len);
-      output->Resize(make_ddim(dims_vec));
+      output->Resize(framework::make_ddim(dims_vec));
     } else {
       output->mutable_data<float>(platform::CPUPlace());
       for (size_t j = 0; j < seq_len; j++) {
diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h
index f859dc333d20ecf04fcb080e459391c8a0122bb1..c5931773d1d601c4f85b35a031c00de5008a28f8 100644
--- a/paddle/operators/recurrent_op.h
+++ b/paddle/operators/recurrent_op.h
@@ -68,7 +68,7 @@ struct ArgumentName {
 /**
  * Prepare inputs for each step net.
  */
-void SegmentInputs(const std::vector<Scope*>& step_scopes,
+void SegmentInputs(const std::vector<framework::Scope*>& step_scopes,
                    const std::vector<Link>& inlinks,
                    const size_t seq_len,
                    bool infer_shape_mode);
@@ -76,12 +76,12 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
 /**
  * Process outputs of step nets and merge to variables.
  */
-void ConcatOutputs(const std::vector<Scope*>& step_scopes,
+void ConcatOutputs(const std::vector<framework::Scope*>& step_scopes,
                    const std::vector<Link>& outlinks,
                    const size_t seq_len,
                    bool infer_shape_mode);
 
-void LinkMemories(const std::vector<Scope*>& step_scopes,
+void LinkMemories(const std::vector<framework::Scope*>& step_scopes,
                   const std::vector<MemoryAttr>& memories,
                   const size_t step_id,
                   const int offset,
@@ -101,14 +101,15 @@ void InitArgument(const ArgumentName& name, Argument* arg);
 
 class RecurrentAlgorithm {
 public:
-  void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const;
+  void Run(const framework::Scope& scope,
+           const platform::DeviceContext& dev_ctx) const;
 
   void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
 
   /**
    * InferShape must be called before Run.
    */
-  void InferShape(const Scope& scope) const;
+  void InferShape(const framework::Scope& scope) const;
 
 protected:
   /*
@@ -117,13 +118,15 @@ protected:
    * NOTE the scopes are reused in both the forward and backward, so just
    * create once and expand its size if more steps need.
    */
-  void CreateScopes(const Scope& scope) const;
+  void CreateScopes(const framework::Scope& scope) const;
 
-  const std::vector<Scope*>& GetStepScopes(const Scope& scope) const {
-    return *scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();
+  const std::vector<framework::Scope*>& GetStepScopes(
+      const framework::Scope& scope) const {
+    return *scope.FindVar(arg_->step_scopes)
+                ->GetMutable<std::vector<framework::Scope*>>();
   }
 
-  void InitMemories(Scope* step_scopes, bool infer_shape_mode) const;
+  void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
 
 private:
   std::unique_ptr<rnn::Argument> arg_;
@@ -144,18 +147,22 @@ class RecurrentGradientAlgorithm {
 public:
   void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
 
-  void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const;
+  void Run(const framework::Scope& scope,
+           const platform::DeviceContext& dev_ctx) const;
 
-  void LinkBootMemoryGradients(Scope* step_scopes, bool infer_shape_mode) const;
+  void LinkBootMemoryGradients(framework::Scope* step_scopes,
+                               bool infer_shape_mode) const;
 
   /**
    * InferShape must be called before Run.
    */
-  void InferShape(const Scope& scope) const;
+  void InferShape(const framework::Scope& scope) const;
 
 protected:
-  inline const std::vector<Scope*>& GetStepScopes(const Scope& scope) const {
-    return *scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();
+  inline const std::vector<framework::Scope*>& GetStepScopes(
+      const framework::Scope& scope) const {
+    return *scope.FindVar(arg_->step_scopes)
+                ->GetMutable<std::vector<framework::Scope*>>();
   }
 
 private:
@@ -163,16 +170,18 @@ private:
   mutable size_t seq_len_;
 };
 
-class RecurrentOp final : public OperatorBase {
+class RecurrentOp final : public framework::OperatorBase {
 public:
   void Init() override;
 
   /**
    * InferShape must be called before Run.
    */
-  void InferShape(const Scope& scope) const override { alg_.InferShape(scope); }
+  void InferShape(const framework::Scope& scope) const override {
+    alg_.InferShape(scope);
+  }
 
-  void Run(const Scope& scope,
+  void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const override {
     alg_.Run(scope, dev_ctx);
   }
@@ -183,16 +192,18 @@ private:
   RecurrentAlgorithm alg_;
 };
 
-class RecurrentGradientOp final : public OperatorBase {
+class RecurrentGradientOp final : public framework::OperatorBase {
 public:
   void Init() override;
 
   /**
    * InferShape must be called before Run.
    */
-  void InferShape(const Scope& scope) const override { alg_.InferShape(scope); }
+  void InferShape(const framework::Scope& scope) const override {
+    alg_.InferShape(scope);
+  }
 
-  void Run(const Scope& scope,
+  void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const override {
     alg_.Run(scope, dev_ctx);
   }
diff --git a/paddle/operators/recurrent_op_test.cc b/paddle/operators/recurrent_op_test.cc
index 08a6d9fe5681fdea180de2e9361734ade8564775..f450167c83e84c7f38dd5bc1a8debfc895f8ff00 100644
--- a/paddle/operators/recurrent_op_test.cc
+++ b/paddle/operators/recurrent_op_test.cc
@@ -16,6 +16,7 @@
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 
+#include "paddle/framework/ddim.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/tensor.h"
@@ -24,6 +25,9 @@
 namespace paddle {
 namespace operators {
 
+using framework::make_ddim;
+using framework::DDim;
+
 class RecurrentOpTest : public ::testing::Test {
 protected:
   virtual void SetUp() override {
@@ -72,7 +76,7 @@
   }
 
   void CreateRNNOp() {
-    OpDesc op_desc;
+    framework::OpDesc op_desc;
     op_desc.set_type("recurrent_op");
     // inlinks 0
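
For reference, the mechanical pattern this diff applies is plain C++ namespace qualification: once the operator header no longer pulls names like `Scope`, `DDim`, and `OperatorBase` into scope with `using` declarations, every mention in the header must be spelled out as `framework::...`, while a `.cc` or test file may still add local `using` aliases for brevity (as recurrent_op_test.cc does above with `using framework::make_ddim;`). Below is a minimal standalone sketch of that idea, not Paddle code: `framework::Scope` here is a stand-in struct, and `Run` is an illustrative free function.

// Minimal sketch (illustrative names, not the real Paddle classes).
#include <iostream>
#include <vector>

namespace framework {
struct Scope {};  // stand-in for paddle::framework::Scope
}  // namespace framework

namespace operators {
// No `using framework::Scope;` at namespace scope: a header-style
// declaration must fully qualify the type, as the diff above does.
void Run(const framework::Scope& scope,
         const std::vector<framework::Scope*>& step_scopes) {
  (void)scope;  // unused in this sketch
  std::cout << "ran over " << step_scopes.size() << " step scopes\n";
}
}  // namespace operators

int main() {
  framework::Scope root;
  std::vector<framework::Scope*> steps{&root};
  operators::Run(root, steps);
  return 0;
}

The usual rationale for this split is that a `using` declaration at namespace scope in a header leaks into every translation unit that includes it, so headers qualify names fully and only implementation and test files take the shortcut.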