Commit 301a21d8 authored by Yi Wang

cpplint recurrent_op*

Parent 5ae7a5f1
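The change is mechanical: inside paddle::operators, names from paddle::framework that were referenced unqualified (Scope, DDim, OperatorBase, OpDesc, make_ddim) are spelled with an explicit framework:: prefix, and declarations that overflow the 80-column limit after the change are re-wrapped. Google C++ style, of which cpplint enforces a subset, disallows convenience using-declarations at namespace scope in headers, and spelling out the qualification removes the need for them. Below is a minimal, self-contained sketch of the pattern; Inspect and the one-line Scope are illustrative stand-ins, not PaddlePaddle code:

#include <iostream>

namespace framework {
struct Scope {};  // toy stand-in for paddle::framework::Scope
}  // namespace framework

namespace operators {

// Before the cleanup, a namespace-scope using-declaration let callers write
// the bare name:
//   using framework::Scope;
//   void Inspect(const Scope& scope);
//
// After the cleanup the qualification is spelled out, so no
// using-declaration is needed in the header:
void Inspect(const framework::Scope& scope) {
  std::cout << "scope at " << &scope << '\n';
}

}  // namespace operators

int main() {
  framework::Scope scope;
  operators::Inspect(scope);
  return 0;
}

The same rewrite, applied uniformly, produces every hunk below.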
@@ -38,10 +38,10 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    "input link [%s] is not in scope.",
                    inlinks[i].external);
     Tensor* input = input_var->GetMutable<Tensor>();
-    DDim dims = input->dims();
+    framework::DDim dims = input->dims();
     PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                    "all the inlinks must have same length");
-    DDim step_dims = slice_ddim(dims, 1, dims.size());
+    framework::DDim step_dims = slice_ddim(dims, 1, dims.size());
     for (size_t j = 0; j < seq_len; j++) {
       Tensor* step_input =
           step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
@@ -64,13 +64,13 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    outlinks[i].external);
     Tensor* output = output_var->GetMutable<Tensor>();
     if (infer_shape_mode) {
-      DDim step_dims = step_scopes[0]
-                           ->FindVar(outlinks[i].internal)
-                           ->GetMutable<Tensor>()
-                           ->dims();
+      framework::DDim step_dims = step_scopes[0]
+                                      ->FindVar(outlinks[i].internal)
+                                      ->GetMutable<Tensor>()
+                                      ->dims();
       std::vector<int> dims_vec = vectorize(step_dims);
       dims_vec.insert(dims_vec.begin(), seq_len);
-      output->Resize(make_ddim(dims_vec));
+      output->Resize(framework::make_ddim(dims_vec));
     } else {
       output->mutable_data<float>(platform::CPUPlace());
       for (size_t j = 0; j < seq_len; j++) {
...
@@ -68,7 +68,7 @@ struct ArgumentName {
 /**
  * Prepare inputs for each step net.
  */
-void SegmentInputs(const std::vector<Scope*>& step_scopes,
+void SegmentInputs(const std::vector<framework::Scope*>& step_scopes,
                    const std::vector<Link>& inlinks,
                    const size_t seq_len,
                    bool infer_shape_mode);
@@ -76,12 +76,12 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
 /**
  * Process outputs of step nets and merge to variables.
  */
-void ConcatOutputs(const std::vector<Scope*>& step_scopes,
+void ConcatOutputs(const std::vector<framework::Scope*>& step_scopes,
                    const std::vector<Link>& outlinks,
                    const size_t seq_len,
                    bool infer_shape_mode);
 
-void LinkMemories(const std::vector<Scope*>& step_scopes,
+void LinkMemories(const std::vector<framework::Scope*>& step_scopes,
                   const std::vector<MemoryAttr>& memories,
                   const size_t step_id,
                   const int offset,
@@ -101,14 +101,15 @@ void InitArgument(const ArgumentName& name, Argument* arg);
 class RecurrentAlgorithm {
 public:
-  void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const;
+  void Run(const framework::Scope& scope,
+           const platform::DeviceContext& dev_ctx) const;
 
   void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
 
   /**
    * InferShape must be called before Run.
    */
-  void InferShape(const Scope& scope) const;
+  void InferShape(const framework::Scope& scope) const;
 
 protected:
   /*
@@ -117,13 +118,15 @@ protected:
    * NOTE the scopes are reused in both the forward and backward, so just
    * create once and expand its size if more steps need.
    */
-  void CreateScopes(const Scope& scope) const;
+  void CreateScopes(const framework::Scope& scope) const;
 
-  const std::vector<Scope*>& GetStepScopes(const Scope& scope) const {
-    return *scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();
+  const std::vector<framework::Scope*>& GetStepScopes(
+      const framework::Scope& scope) const {
+    return *scope.FindVar(arg_->step_scopes)
+                ->GetMutable<std::vector<framework::Scope*>>();
   }
 
-  void InitMemories(Scope* step_scopes, bool infer_shape_mode) const;
+  void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
 
 private:
   std::unique_ptr<rnn::Argument> arg_;
@@ -144,18 +147,22 @@ class RecurrentGradientAlgorithm {
 public:
   void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
 
-  void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const;
+  void Run(const framework::Scope& scope,
+           const platform::DeviceContext& dev_ctx) const;
 
-  void LinkBootMemoryGradients(Scope* step_scopes, bool infer_shape_mode) const;
+  void LinkBootMemoryGradients(framework::Scope* step_scopes,
+                               bool infer_shape_mode) const;
 
   /**
    * InferShape must be called before Run.
    */
-  void InferShape(const Scope& scope) const;
+  void InferShape(const framework::Scope& scope) const;
 
 protected:
-  inline const std::vector<Scope*>& GetStepScopes(const Scope& scope) const {
-    return *scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();
+  inline const std::vector<framework::Scope*>& GetStepScopes(
+      const framework::Scope& scope) const {
+    return *scope.FindVar(arg_->step_scopes)
+                ->GetMutable<std::vector<framework::Scope*>>();
   }
 
 private:
@@ -163,16 +170,18 @@ private:
   mutable size_t seq_len_;
 };
 
-class RecurrentOp final : public OperatorBase {
+class RecurrentOp final : public framework::OperatorBase {
 public:
   void Init() override;
 
   /**
    * InferShape must be called before Run.
    */
-  void InferShape(const Scope& scope) const override { alg_.InferShape(scope); }
+  void InferShape(const framework::Scope& scope) const override {
+    alg_.InferShape(scope);
+  }
 
-  void Run(const Scope& scope,
+  void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const override {
     alg_.Run(scope, dev_ctx);
   }
@@ -183,16 +192,18 @@ private:
   RecurrentAlgorithm alg_;
 };
 
-class RecurrentGradientOp final : public OperatorBase {
+class RecurrentGradientOp final : public framework::OperatorBase {
 public:
   void Init() override;
 
   /**
    * InferShape must be called before Run.
    */
-  void InferShape(const Scope& scope) const override { alg_.InferShape(scope); }
+  void InferShape(const framework::Scope& scope) const override {
+    alg_.InferShape(scope);
+  }
 
-  void Run(const Scope& scope,
+  void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const override {
     alg_.Run(scope, dev_ctx);
   }
...
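One re-wrap above is worth unpacking: GetStepScopes no longer fits its original two lines once the element type is fully qualified, because the per-step scopes are stored as a std::vector of Scope pointers inside a variable of the enclosing scope and pulled back out through FindVar/GetMutable. Below is a runnable toy sketch of that access pattern, assuming simplified stand-ins for Variable and Scope (the real framework::Variable type-erases its payload; this one holds a single fixed type):

#include <iostream>
#include <vector>

namespace framework {
struct Scope;

// Toy Variable holding one payload of a fixed type. The access pattern
// FindVar(...)->GetMutable<std::vector<framework::Scope*>>() mirrors the
// header above, even though the real class is far more general.
struct Variable {
  std::vector<Scope*> payload;
  template <typename T>
  T* GetMutable() {
    return &payload;
  }
};

struct Scope {
  mutable Variable var;  // toy scope with a single, unnamed variable
  Variable* FindVar(const char* /*name*/) const { return &var; }
};
}  // namespace framework

// Mirrors RecurrentAlgorithm::GetStepScopes: the step scopes live in a
// variable of the enclosing scope, registered under arg_->step_scopes.
const std::vector<framework::Scope*>& GetStepScopes(
    const framework::Scope& scope) {
  return *scope.FindVar("step_scopes")
              ->GetMutable<std::vector<framework::Scope*>>();
}

int main() {
  framework::Scope outer, step0, step1;
  auto* scopes = outer.FindVar("step_scopes")
                     ->GetMutable<std::vector<framework::Scope*>>();
  scopes->push_back(&step0);
  scopes->push_back(&step1);
  std::cout << GetStepScopes(outer).size() << " step scopes\n";  // prints 2
  return 0;
}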
@@ -16,6 +16,7 @@
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 
+#include "paddle/framework/ddim.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/tensor.h"
@@ -24,6 +25,9 @@
 namespace paddle {
 namespace operators {
 
+using framework::make_ddim;
+using framework::DDim;
+
 class RecurrentOpTest : public ::testing::Test {
 protected:
   virtual void SetUp() override {
@@ -72,7 +76,7 @@ protected:
   }
 
   void CreateRNNOp() {
-    OpDesc op_desc;
+    framework::OpDesc op_desc;
     op_desc.set_type("recurrent_op");
 
     // inlinks 0
...