Commit 301a21d8 authored by Yi Wang

cpplint recurrent_op*

Parent 5ae7a5f1
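The pattern throughout is mechanical: names that `recurrent_op*` had been using unqualified (`Scope`, `DDim`, `OpDesc`, `OperatorBase`) are now spelled with an explicit `framework::` prefix, so the headers no longer depend on a using-declaration being in scope. A minimal before/after sketch of the pattern, with identifiers taken from the hunks below:

```cpp
// Before: compiles only if a using-declaration (e.g. `using framework::DDim;`)
// is visible somewhere.
DDim dims = input->dims();

// After: the owning namespace is spelled out at the point of use.
framework::DDim dims = input->dims();
```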
@@ -38,10 +38,10 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    "input link [%s] is not in scope.",
                    inlinks[i].external);
     Tensor* input = input_var->GetMutable<Tensor>();
-    DDim dims = input->dims();
+    framework::DDim dims = input->dims();
     PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
                    "all the inlinks must have same length");
-    DDim step_dims = slice_ddim(dims, 1, dims.size());
+    framework::DDim step_dims = slice_ddim(dims, 1, dims.size());
     for (size_t j = 0; j < seq_len; j++) {
       Tensor* step_input =
           step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
@@ -64,13 +64,13 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    outlinks[i].external);
     Tensor* output = output_var->GetMutable<Tensor>();
     if (infer_shape_mode) {
-      DDim step_dims = step_scopes[0]
+      framework::DDim step_dims = step_scopes[0]
                            ->FindVar(outlinks[i].internal)
                            ->GetMutable<Tensor>()
                            ->dims();
       std::vector<int> dims_vec = vectorize(step_dims);
       dims_vec.insert(dims_vec.begin(), seq_len);
-      output->Resize(make_ddim(dims_vec));
+      output->Resize(framework::make_ddim(dims_vec));
     } else {
       output->mutable_data<float>(platform::CPUPlace());
       for (size_t j = 0; j < seq_len; j++) {
......
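For orientation, the shape arithmetic in the two hunks above is symmetric: `SegmentInputs` drops the leading sequence axis to get a per-step shape, and `ConcatOutputs` re-prepends `seq_len` when merging step outputs. A standalone sketch under made-up shapes (the function name `ShapeRoundTrip` and the dimensions are hypothetical; `make_ddim`, `slice_ddim`, and `vectorize` are the `framework` helpers used in the hunks):

```cpp
#include <vector>

#include "paddle/framework/ddim.h"

using paddle::framework::DDim;

// Illustrative only; the shapes are invented. A 10-step sequence of
// [20, 30] tensors round-trips through the split/merge shape logic.
void ShapeRoundTrip() {
  const size_t seq_len = 10;
  DDim dims = paddle::framework::make_ddim({10, 20, 30});
  // Per-step shape, as in SegmentInputs: drop axis 0 -> [20, 30].
  DDim step_dims = paddle::framework::slice_ddim(dims, 1, dims.size());
  // Merged shape, as in ConcatOutputs: re-prepend seq_len -> [10, 20, 30].
  std::vector<int> dims_vec = paddle::framework::vectorize(step_dims);
  dims_vec.insert(dims_vec.begin(), seq_len);
  DDim out_dims = paddle::framework::make_ddim(dims_vec);
  (void)out_dims;
}
```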
@@ -68,7 +68,7 @@ struct ArgumentName {
 /**
  * Prepare inputs for each step net.
  */
-void SegmentInputs(const std::vector<Scope*>& step_scopes,
+void SegmentInputs(const std::vector<framework::Scope*>& step_scopes,
                    const std::vector<Link>& inlinks,
                    const size_t seq_len,
                    bool infer_shape_mode);
@@ -76,12 +76,12 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
 /**
  * Process outputs of step nets and merge to variables.
  */
-void ConcatOutputs(const std::vector<Scope*>& step_scopes,
+void ConcatOutputs(const std::vector<framework::Scope*>& step_scopes,
                    const std::vector<Link>& outlinks,
                    const size_t seq_len,
                    bool infer_shape_mode);
 
-void LinkMemories(const std::vector<Scope*>& step_scopes,
+void LinkMemories(const std::vector<framework::Scope*>& step_scopes,
                   const std::vector<MemoryAttr>& memories,
                   const size_t step_id,
                   const int offset,
@@ -101,14 +101,15 @@ void InitArgument(const ArgumentName& name, Argument* arg);
 
 class RecurrentAlgorithm {
 public:
-  void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const;
+  void Run(const framework::Scope& scope,
+           const platform::DeviceContext& dev_ctx) const;
 
   void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
 
   /**
    * InferShape must be called before Run.
    */
-  void InferShape(const Scope& scope) const;
+  void InferShape(const framework::Scope& scope) const;
 
 protected:
   /*
@@ -117,13 +118,15 @@ protected:
    * NOTE the scopes are reused in both the forward and backward, so just
    * create once and expand its size if more steps need.
    */
-  void CreateScopes(const Scope& scope) const;
+  void CreateScopes(const framework::Scope& scope) const;
 
-  const std::vector<Scope*>& GetStepScopes(const Scope& scope) const {
-    return *scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();
+  const std::vector<framework::Scope*>& GetStepScopes(
+      const framework::Scope& scope) const {
+    return *scope.FindVar(arg_->step_scopes)
+                ->GetMutable<std::vector<framework::Scope*>>();
   }
 
-  void InitMemories(Scope* step_scopes, bool infer_shape_mode) const;
+  void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
 
 private:
   std::unique_ptr<rnn::Argument> arg_;
@@ -144,18 +147,22 @@ class RecurrentGradientAlgorithm {
 public:
   void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
 
-  void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const;
+  void Run(const framework::Scope& scope,
+           const platform::DeviceContext& dev_ctx) const;
 
-  void LinkBootMemoryGradients(Scope* step_scopes, bool infer_shape_mode) const;
+  void LinkBootMemoryGradients(framework::Scope* step_scopes,
+                               bool infer_shape_mode) const;
 
   /**
    * InferShape must be called before Run.
   */
-  void InferShape(const Scope& scope) const;
+  void InferShape(const framework::Scope& scope) const;
 
 protected:
-  inline const std::vector<Scope*>& GetStepScopes(const Scope& scope) const {
-    return *scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();
+  inline const std::vector<framework::Scope*>& GetStepScopes(
+      const framework::Scope& scope) const {
+    return *scope.FindVar(arg_->step_scopes)
+                ->GetMutable<std::vector<framework::Scope*>>();
   }
 
 private:
@@ -163,16 +170,18 @@ private:
   mutable size_t seq_len_;
 };
 
-class RecurrentOp final : public OperatorBase {
+class RecurrentOp final : public framework::OperatorBase {
 public:
   void Init() override;
 
   /**
    * InferShape must be called before Run.
    */
-  void InferShape(const Scope& scope) const override { alg_.InferShape(scope); }
+  void InferShape(const framework::Scope& scope) const override {
+    alg_.InferShape(scope);
+  }
 
-  void Run(const Scope& scope,
+  void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const override {
     alg_.Run(scope, dev_ctx);
   }
@@ -183,16 +192,18 @@ private:
   RecurrentAlgorithm alg_;
 };
 
-class RecurrentGradientOp final : public OperatorBase {
+class RecurrentGradientOp final : public framework::OperatorBase {
 public:
   void Init() override;
 
   /**
    * InferShape must be called before Run.
    */
-  void InferShape(const Scope& scope) const override { alg_.InferShape(scope); }
+  void InferShape(const framework::Scope& scope) const override {
+    alg_.InferShape(scope);
+  }
 
-  void Run(const Scope& scope,
+  void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const override {
     alg_.Run(scope, dev_ctx);
   }
......
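Both algorithm classes above now share the same fully qualified lookup idiom: the vector of step scopes is itself stored as a variable inside the parent scope and fetched with `FindVar`/`GetMutable`. A self-contained sketch of that idiom (the free function `LookupStepScopes` is hypothetical; the calls mirror the `GetStepScopes` bodies above):

```cpp
#include <string>
#include <vector>

#include "paddle/framework/scope.h"

// Hypothetical free-function form of the GetStepScopes accessors.
// Assumes `scope` already holds a variable named `name` whose payload is
// a std::vector<framework::Scope*> (created earlier, e.g. by CreateScopes).
const std::vector<paddle::framework::Scope*>& LookupStepScopes(
    const paddle::framework::Scope& scope, const std::string& name) {
  return *scope.FindVar(name)
              ->GetMutable<std::vector<paddle::framework::Scope*>>();
}
```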
@@ -16,6 +16,7 @@
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 
+#include "paddle/framework/ddim.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/tensor.h"
@@ -24,6 +25,9 @@
 namespace paddle {
 namespace operators {
 
+using framework::make_ddim;
+using framework::DDim;
+
 class RecurrentOpTest : public ::testing::Test {
 protected:
   virtual void SetUp() override {
@@ -72,7 +76,7 @@ protected:
   }
 
   void CreateRNNOp() {
-    OpDesc op_desc;
+    framework::OpDesc op_desc;
     op_desc.set_type("recurrent_op");
     // inlinks 0
......
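The test file takes the complementary approach: rather than qualifying every use, it adds using-declarations inside `namespace paddle { namespace operators {`. That keeps the test body terse while confining the short names to one translation unit; a header could not do the same without leaking the names into every includer. Roughly (the helper `MakeInputShape` and its shape are hypothetical):

```cpp
#include "paddle/framework/ddim.h"

namespace paddle {
namespace operators {

// Visible only in this translation unit; includers are unaffected.
using framework::make_ddim;
using framework::DDim;

// Hypothetical helper mirroring the test's setup code.
DDim MakeInputShape() { return make_ddim({10, 20}); }

}  // namespace operators
}  // namespace paddle
```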