提交 5d134a03 编写于 作者: Y Yu Yang

Refine remove std::shared_ptr in Scope

* Make interface of Operator to `const Scope&`
上级 d3ddf050
...@@ -43,7 +43,7 @@ class NetOp : public OperatorBase { ...@@ -43,7 +43,7 @@ class NetOp : public OperatorBase {
* Infer all the operators' input and output variables' shapes, will be called * Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch * before every mini-batch
*/ */
void InferShape(const std::shared_ptr<Scope>& scope) const override { void InferShape(const Scope& scope) const override {
for (auto& op : ops_) { for (auto& op : ops_) {
op->InferShape(scope); op->InferShape(scope);
} }
...@@ -56,7 +56,7 @@ class NetOp : public OperatorBase { ...@@ -56,7 +56,7 @@ class NetOp : public OperatorBase {
* scope will be used instead. If no OpContext is provicded, default context * scope will be used instead. If no OpContext is provicded, default context
* will be used. * will be used.
*/ */
void Run(const std::shared_ptr<Scope>& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
for (auto& op : ops_) { for (auto& op : ops_) {
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
......
...@@ -16,11 +16,10 @@ static int run_cnt = 0; ...@@ -16,11 +16,10 @@ static int run_cnt = 0;
class TestOp : public OperatorBase { class TestOp : public OperatorBase {
public: public:
void InferShape( void InferShape(const framework::Scope& scope) const override {
const std::shared_ptr<framework::Scope>& scope) const override {
++infer_shape_cnt; ++infer_shape_cnt;
} }
void Run(const std::shared_ptr<framework::Scope>& scope, void Run(const framework::Scope& scope,
const paddle::platform::DeviceContext& dev_ctx) const override { const paddle::platform::DeviceContext& dev_ctx) const override {
++run_cnt; ++run_cnt;
} }
...@@ -62,7 +61,7 @@ TEST(OpKernel, all) { ...@@ -62,7 +61,7 @@ TEST(OpKernel, all) {
ASSERT_EQ(1UL, tmp_idx.size()); ASSERT_EQ(1UL, tmp_idx.size());
ASSERT_EQ("y", net->outputs_[tmp_idx[0]]); ASSERT_EQ("y", net->outputs_[tmp_idx[0]]);
auto scope = std::make_shared<Scope>(); Scope scope;
platform::CPUDeviceContext dev_ctx; platform::CPUDeviceContext dev_ctx;
net->InferShape(scope); net->InferShape(scope);
......
...@@ -7,9 +7,9 @@ namespace paddle { ...@@ -7,9 +7,9 @@ namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorBase { class CosineOp : public OperatorBase {
public: public:
void Run(const std::shared_ptr<Scope>& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(const Scope& scope) const override {}
}; };
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...@@ -27,8 +27,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -27,8 +27,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase { class MyTestOp : public OperatorBase {
public: public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
}; };
...@@ -69,7 +69,7 @@ TEST(OpRegistry, CreateOp) { ...@@ -69,7 +69,7 @@ TEST(OpRegistry, CreateOp) {
std::shared_ptr<paddle::framework::OperatorBase> op = std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<paddle::framework::Scope>(); paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
float scale_get = op->GetAttr<float>("scale"); float scale_get = op->GetAttr<float>("scale");
...@@ -111,7 +111,7 @@ TEST(OpRegistry, DefaultValue) { ...@@ -111,7 +111,7 @@ TEST(OpRegistry, DefaultValue) {
std::shared_ptr<paddle::framework::OperatorBase> op = std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<paddle::framework::Scope>(); paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0); ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
...@@ -173,7 +173,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -173,7 +173,7 @@ TEST(OpRegistry, CustomChecker) {
SetInputFormat(&op_desc); SetInputFormat(&op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<paddle::framework::Scope>(); paddle::framework::Scope scope;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
int test_attr = op->GetAttr<int>("test_attr"); int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4); ASSERT_EQ(test_attr, 4);
......
...@@ -84,10 +84,10 @@ class OperatorBase { ...@@ -84,10 +84,10 @@ class OperatorBase {
/// InferShape infer the size of Variables used by this Operator with /// InferShape infer the size of Variables used by this Operator with
/// information inside scope /// information inside scope
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0; virtual void InferShape(const Scope& scope) const = 0;
/// Net will call this function to Run an op. /// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope, virtual void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const = 0; const platform::DeviceContext& dev_ctx) const = 0;
virtual bool IsNetOp() const { return false; } virtual bool IsNetOp() const { return false; }
...@@ -114,24 +114,24 @@ class OperatorBase { ...@@ -114,24 +114,24 @@ class OperatorBase {
class KernelContext { class KernelContext {
public: public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope, KernelContext(const OperatorBase* op, const Scope& scope,
const platform::DeviceContext& device_context) const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {} : op_(*op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const { const Variable* Input(int index) const {
return scope_->FindVar(op_.inputs_[index]); return scope_.FindVar(op_.inputs_[index]);
} }
Variable* Output(int index) const { Variable* Output(int index) const {
return scope_->FindVar(op_.outputs_[index]); return scope_.FindVar(op_.outputs_[index]);
} }
const Variable* Input(const std::string& name) const { const Variable* Input(const std::string& name) const {
return scope_->FindVar(op_.Input(name)); return scope_.FindVar(op_.Input(name));
} }
const Variable* Output(const std::string& name) const { const Variable* Output(const std::string& name) const {
return scope_->FindVar(op_.Output(name)); return scope_.FindVar(op_.Output(name));
} }
const std::vector<const Variable*> Inputs(const std::string& name) const { const std::vector<const Variable*> Inputs(const std::string& name) const {
...@@ -139,7 +139,7 @@ class KernelContext { ...@@ -139,7 +139,7 @@ class KernelContext {
std::vector<const Variable*> res; std::vector<const Variable*> res;
std::transform( std::transform(
names.begin(), names.end(), res.begin(), names.begin(), names.end(), res.begin(),
[this](const std::string& name) { return scope_->FindVar(name); }); [this](const std::string& name) { return scope_.FindVar(name); });
return res; return res;
} }
...@@ -148,7 +148,7 @@ class KernelContext { ...@@ -148,7 +148,7 @@ class KernelContext {
std::vector<const Variable*> res; std::vector<const Variable*> res;
std::transform( std::transform(
names.begin(), names.end(), res.begin(), names.begin(), names.end(), res.begin(),
[this](const std::string& name) { return scope_->FindVar(name); }); [this](const std::string& name) { return scope_.FindVar(name); });
return res; return res;
} }
...@@ -160,7 +160,7 @@ class KernelContext { ...@@ -160,7 +160,7 @@ class KernelContext {
platform::Place GetPlace() const { return device_context_.GetPlace(); } platform::Place GetPlace() const { return device_context_.GetPlace(); }
const OperatorBase& op_; const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_; const Scope& scope_;
const platform::DeviceContext& device_context_; const platform::DeviceContext& device_context_;
}; };
...@@ -216,7 +216,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -216,7 +216,7 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap = using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>; std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void Run(const std::shared_ptr<Scope>& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final { const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(KernelContext(this, scope, dev_ctx)); opKernel->Compute(KernelContext(this, scope, dev_ctx));
...@@ -228,7 +228,7 @@ class OperatorWithKernel : public OperatorBase { ...@@ -228,7 +228,7 @@ class OperatorWithKernel : public OperatorBase {
return g_all_op_kernels; return g_all_op_kernels;
} }
void InferShape(const std::shared_ptr<Scope>& scope) const final { void InferShape(const Scope& scope) const final {
std::vector<const Tensor*> ins; std::vector<const Tensor*> ins;
VarNamesToTensors(scope, inputs_, &ins); VarNamesToTensors(scope, inputs_, &ins);
std::vector<Tensor*> outs; std::vector<Tensor*> outs;
...@@ -238,13 +238,13 @@ class OperatorWithKernel : public OperatorBase { ...@@ -238,13 +238,13 @@ class OperatorWithKernel : public OperatorBase {
private: private:
template <typename T> template <typename T>
void VarNamesToTensors(const std::shared_ptr<Scope>& scope, void VarNamesToTensors(const Scope& scope,
const std::vector<std::string>& var_names, const std::vector<std::string>& var_names,
std::vector<T>* container) const { std::vector<T>* container) const {
container->reserve(var_names.size()); container->reserve(var_names.size());
VarToTensor<T> convert; VarToTensor<T> convert;
for (auto& name : var_names) { for (auto& name : var_names) {
auto var = scope->FindVar(name); auto var = scope.FindVar(name);
if (var != nullptr) { if (var != nullptr) {
container->push_back(convert(var)); container->push_back(convert(var));
} else { } else {
......
...@@ -24,15 +24,15 @@ static int op_run_num = 0; ...@@ -24,15 +24,15 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase { class OpWithoutKernelTest : public OperatorBase {
public: public:
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
op_run_num++; op_run_num++;
ASSERT_EQ((int)inputs_.size(), 1); ASSERT_EQ((int)inputs_.size(), 1);
ASSERT_EQ((int)outputs_.size(), 1); ASSERT_EQ((int)outputs_.size(), 1);
ASSERT_EQ(scope->FindVar(inputs_[0]), nullptr); ASSERT_EQ(scope.FindVar(inputs_[0]), nullptr);
ASSERT_EQ(x, 1); ASSERT_EQ(x, 1);
ASSERT_NE(scope->FindVar(outputs_[0]), nullptr); ASSERT_NE(scope.FindVar(outputs_[0]), nullptr);
} }
public: public:
...@@ -68,10 +68,10 @@ TEST(OperatorBase, all) { ...@@ -68,10 +68,10 @@ TEST(OperatorBase, all) {
attr->set_f(3.14); attr->set_f(3.14);
paddle::platform::CPUDeviceContext device_context; paddle::platform::CPUDeviceContext device_context;
auto scope = std::make_shared<paddle::framework::Scope>(); paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope->NewVar("OUT1"); scope.NewVar("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0); ASSERT_EQ(paddle::framework::op_run_num, 0);
op->Run(scope, device_context); op->Run(scope, device_context);
ASSERT_EQ(paddle::framework::op_run_num, 1); ASSERT_EQ(paddle::framework::op_run_num, 1);
...@@ -117,12 +117,12 @@ class CPUKernelTest : public OpKernel { ...@@ -117,12 +117,12 @@ class CPUKernelTest : public OpKernel {
class OperatorMultiInputsTest : public OperatorBase { class OperatorMultiInputsTest : public OperatorBase {
public: public:
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
ASSERT_EQ(scope->FindVar(inputs_[0]), nullptr); ASSERT_EQ(scope.FindVar(inputs_[0]), nullptr);
ASSERT_EQ(x, 1); ASSERT_EQ(x, 1);
ASSERT_NE(scope->FindVar(outputs_[0]), nullptr); ASSERT_NE(scope.FindVar(outputs_[0]), nullptr);
ASSERT_EQ(Input("x"), "IN1"); ASSERT_EQ(Input("x"), "IN1");
ASSERT_EQ(Input("y"), "OUT1"); ASSERT_EQ(Input("y"), "OUT1");
} }
...@@ -186,7 +186,7 @@ TEST(OpKernel, all) { ...@@ -186,7 +186,7 @@ TEST(OpKernel, all) {
attr->set_f(3.14); attr->set_f(3.14);
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<paddle::framework::Scope>(); paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
...@@ -232,7 +232,7 @@ TEST(OpKernel, multi_inputs) { ...@@ -232,7 +232,7 @@ TEST(OpKernel, multi_inputs) {
output_format->Add(2); // y1 output_format->Add(2); // y1
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>(); paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
......
...@@ -19,11 +19,11 @@ namespace paddle { ...@@ -19,11 +19,11 @@ namespace paddle {
namespace framework { namespace framework {
Scope::~Scope() { Scope::~Scope() {
DropKids();
for (auto& kv : vars_) delete kv.second; for (auto& kv : vars_) delete kv.second;
for (Scope* s : kids_) delete s;
} }
Scope& Scope::NewScope() { Scope& Scope::NewScope() const {
kids_.push_back(new Scope(this)); kids_.push_back(new Scope(this));
return *kids_.back(); return *kids_.back();
} }
...@@ -49,7 +49,7 @@ Variable* Scope::FindVar(const std::string& name) const { ...@@ -49,7 +49,7 @@ Variable* Scope::FindVar(const std::string& name) const {
return (parent_ == nullptr) ? nullptr : parent_->FindVar(name); return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
} }
Scope* Scope::FindScope(const Variable* var) { const Scope* Scope::FindScope(const Variable* var) const {
for (auto& kv : vars_) { for (auto& kv : vars_) {
if (kv.second == var) { if (kv.second == var) {
return this; return this;
...@@ -57,6 +57,10 @@ Scope* Scope::FindScope(const Variable* var) { ...@@ -57,6 +57,10 @@ Scope* Scope::FindScope(const Variable* var) {
} }
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var); return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
} }
void Scope::DropKids() {
for (Scope* s : kids_) delete s;
kids_.clear();
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -15,8 +15,8 @@ limitations under the License. */ ...@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include <list> #include <list>
#include <map>
#include <string> #include <string>
#include <unordered_map>
#include "paddle/framework/variable.h" #include "paddle/framework/variable.h"
...@@ -38,30 +38,39 @@ class Scope { ...@@ -38,30 +38,39 @@ class Scope {
Scope() {} Scope() {}
~Scope(); ~Scope();
// Create a sub-scope. Returns a reference other than a pointer so // Disable Copy, Assign, Move.
// to prevent from manual deletion. Scope(const Scope& other) = delete;
Scope& NewScope(); Scope& operator=(const Scope& other) = delete;
Scope(Scope&& other) = delete;
// Create a variable with given name if it doesn't exist. /// Create a sub-scope. Returns a reference other than a pointer so
/// to prevent from manual deletion.
/// Mark it to const because that new kid scope cannot change parent scope.
Scope& NewScope() const;
/// Create a variable with given name if it doesn't exist.
Variable* NewVar(const std::string& name); Variable* NewVar(const std::string& name);
// Create a variable with a scope-unique name. /// Create a variable with a scope-unique name.
Variable* NewVar(); Variable* NewVar();
// Find a variable in the scope or any of its ancestors. Returns /// Find a variable in the scope or any of its ancestors. Returns
// nullptr if cannot find. /// nullptr if cannot find.
Variable* FindVar(const std::string& name) const; Variable* FindVar(const std::string& name) const;
// Find the scope or an ancestor scope that contains the given variable. /// Find the scope or an ancestor scope that contains the given variable.
Scope* FindScope(const Variable* var); const Scope* FindScope(const Variable* var) const;
/// Drop all kids scopes belonged to this scope.
void DropKids();
private: private:
// Call Scope::NewScope for a sub-scope. // Call Scope::NewScope for a sub-scope.
explicit Scope(Scope* parent) : parent_(parent) {} explicit Scope(Scope const* parent) : parent_(parent) {}
std::map<std::string, Variable*> vars_; std::unordered_map<std::string, Variable*> vars_;
std::list<Scope*> kids_; mutable std::list<Scope*> kids_;
Scope* parent_{nullptr}; Scope const* parent_{nullptr};
}; };
} // namespace framework } // namespace framework
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <typeindex> #include <typeindex>
#include <typeinfo> #include <typeinfo>
#include "paddle/platform/assert.h" #include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -25,7 +25,7 @@ class Variable { ...@@ -25,7 +25,7 @@ class Variable {
public: public:
template <typename T> template <typename T>
const T& Get() const { const T& Get() const {
PADDLE_ASSERT(IsType<T>()); PADDLE_ENFORCE(IsType<T>(), "Variable must be type %s", typeid(T).name());
return *static_cast<const T*>(holder_->Ptr()); return *static_cast<const T*>(holder_->Ptr());
} }
......
...@@ -27,7 +27,7 @@ namespace operators { ...@@ -27,7 +27,7 @@ namespace operators {
namespace rnn { namespace rnn {
void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes, void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& inlinks, const std::vector<Link>& inlinks,
const size_t seq_len) { const size_t seq_len) {
PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided."); PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
...@@ -47,7 +47,7 @@ void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes, ...@@ -47,7 +47,7 @@ void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
} }
} }
void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes, void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& outlinks, const std::vector<Link>& outlinks,
const size_t seq_len) { const size_t seq_len) {
for (size_t i = 0; i < outlinks.size(); i++) { for (size_t i = 0; i < outlinks.size(); i++) {
...@@ -75,7 +75,7 @@ void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes, ...@@ -75,7 +75,7 @@ void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
} }
} }
void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes, void LinkMemories(const std::vector<Scope*>& scopes,
const std::vector<rnn::MemoryAttr>& memories, const std::vector<rnn::MemoryAttr>& memories,
size_t step_id, size_t step_id,
int offset) { int offset) {
...@@ -92,8 +92,8 @@ void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes, ...@@ -92,8 +92,8 @@ void LinkMemories(std::vector<std::shared_ptr<Scope>>& scopes,
offset, offset,
scopes.size(), scopes.size(),
step_id); step_id);
std::shared_ptr<Scope> scope = scopes[step_id]; auto scope = scopes[step_id];
std::shared_ptr<Scope> linked_scope = scopes[step_id + offset]; auto linked_scope = scopes[step_id + offset];
for (auto& attr : memories) { for (auto& attr : memories) {
auto mem = scope->NewVar(attr.pre_var)->GetMutable<Tensor>(); auto mem = scope->NewVar(attr.pre_var)->GetMutable<Tensor>();
// maybe share variable is better? // maybe share variable is better?
...@@ -169,8 +169,8 @@ void InitArgument(const ArgumentName& name, ...@@ -169,8 +169,8 @@ void InitArgument(const ArgumentName& name,
} // namespace rnn } // namespace rnn
void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const { void RecurrentAlgorithm::InferShape(const Scope& scope) const {
seq_len_ = scope->FindVar((arg_->inlinks[0]).external) seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
->GetMutable<Tensor>() ->GetMutable<Tensor>()
->dims()[0]; ->dims()[0];
CreateScopes(scope); CreateScopes(scope);
...@@ -185,10 +185,10 @@ void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const { ...@@ -185,10 +185,10 @@ void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const {
InitMemories(step_scopes[0]); InitMemories(step_scopes[0]);
PADDLE_ENFORCE(scope->FindVar(arg_->step_net), PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr,
"stepnet [%s] is not in scope.", "stepnet [%s] is not in scope.",
arg_->step_net); arg_->step_net);
Variable* net = scope->FindVar(arg_->step_net); Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net"); PADDLE_ENFORCE(net != nullptr, "failed to get step net");
// If the InferShape is called in OperatorBase's run function, // If the InferShape is called in OperatorBase's run function,
// the rnn op only needs to do InferShape for the first time step // the rnn op only needs to do InferShape for the first time step
...@@ -196,7 +196,7 @@ void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const { ...@@ -196,7 +196,7 @@ void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const {
if (i > 0) { if (i > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, i, -1); rnn::LinkMemories(step_scopes, arg_->memories, i, -1);
} }
net->GetMutable<NetOp>()->InferShape(step_scopes[i]); net->GetMutable<NetOp>()->InferShape(*step_scopes[i]);
} }
auto outlinks = arg_->outlinks; auto outlinks = arg_->outlinks;
...@@ -214,51 +214,51 @@ void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const { ...@@ -214,51 +214,51 @@ void RecurrentAlgorithm::InferShape(const std::shared_ptr<Scope>& scope) const {
} }
} }
void RecurrentAlgorithm::Run(const std::shared_ptr<Scope>& scope, void RecurrentAlgorithm::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const { const platform::DeviceContext& dev_ctx) const {
auto step_scopes = GetStepScopes(scope); auto step_scopes = GetStepScopes(scope);
Variable* net = scope->FindVar(arg_->step_net); Variable* net = scope.FindVar(arg_->step_net);
for (size_t step_id = 0; step_id < seq_len_; step_id++) { for (size_t step_id = 0; step_id < seq_len_; step_id++) {
// the link memory is done in InferShape // the link memory is done in InferShape
// maybe remove following code after testing // maybe remove following code after testing
if (step_id > 0) { if (step_id > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1); rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
} }
net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx); net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
} }
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
} }
void RecurrentAlgorithm::CreateScopes(std::shared_ptr<Scope> scope) const { void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
// TODO(xxx) Only two scopes are needed for inference, this case will be // TODO(xxx) Only two scopes are needed for inference, this case will be
// supported later. // supported later.
auto step_scopes = scope->FindVar(arg_->step_scopes) auto step_scopes =
->GetMutable<std::vector<std::shared_ptr<Scope>>>(); scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();
if (seq_len_ > step_scopes->size()) { if (seq_len_ > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len_; ++i) { for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
std::shared_ptr<Scope> step_scope = std::make_shared<Scope>(scope); auto& step_scope = scope.NewScope();
// Now all variables in scope must be created outside of op. // Now all variables in scope must be created outside of op.
auto net_op = scope->FindVar(arg_->step_net)->GetMutable<NetOp>(); auto net_op = scope.FindVar(arg_->step_net)->GetMutable<NetOp>();
for (auto& input : net_op->inputs_) { for (auto& input : net_op->inputs_) {
step_scope->NewVar(input); if (!step_scope.FindVar(input)) step_scope.NewVar(input);
} }
for (auto& output : net_op->outputs_) { for (auto& output : net_op->outputs_) {
step_scope->NewVar(output); step_scope.NewVar(output);
} }
step_scopes->push_back(std::make_shared<Scope>(step_scope)); step_scopes->emplace_back(&step_scope);
} }
} }
} }
void RecurrentAlgorithm::InitMemories(std::shared_ptr<Scope> step_scope) const { void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
for (auto& attr : arg_->memories) { for (auto& attr : arg_->memories) {
Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<Tensor>(); Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<Tensor>();
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var), PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
"memory [%s]'s boot variable [%s] not exists", "memory [%s]'s boot variable [%s] not exists",
attr.var, attr.var,
attr.boot_var); attr.boot_var);
...@@ -328,30 +328,30 @@ public: ...@@ -328,30 +328,30 @@ public:
}; };
void RecurrentGradientAlgorithm::Run( void RecurrentGradientAlgorithm::Run(
const std::shared_ptr<Scope>& scope, const Scope& scope, const platform::DeviceContext& dev_ctx) const {
const platform::DeviceContext& dev_ctx) const {
auto step_scopes = GetStepScopes(scope); auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
PADDLE_ENFORCE(scope->FindVar(arg_->step_net), "step net is not in scope."); PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr,
Variable* net = scope->FindVar(arg_->step_net); "step net is not in scope.");
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net"); PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) { if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
} }
net->GetMutable<NetOp>()->Run(step_scopes[step_id], dev_ctx); net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
} }
LinkBootMemoryGradients(step_scopes[0]); LinkBootMemoryGradients(step_scopes[0]);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
} }
void RecurrentGradientAlgorithm::LinkBootMemoryGradients( void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
std::shared_ptr<Scope> step_scope) const { Scope* step_scope) const {
for (auto& attr : arg_->memories) { for (auto& attr : arg_->memories) {
Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>(); Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
PADDLE_ENFORCE(mem_grad != nullptr, PADDLE_ENFORCE(mem_grad != nullptr,
"boot_tensor should be retrieved before"); "boot_tensor should be retrieved before");
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var), PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
"memory [%s]'s boot variable [%s] not exists", "memory [%s]'s boot variable [%s] not exists",
attr.var, attr.var,
attr.boot_var); attr.boot_var);
...@@ -361,23 +361,23 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients( ...@@ -361,23 +361,23 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
} }
} }
void RecurrentGradientAlgorithm::InferShape( void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
const std::shared_ptr<Scope>& scope) const { seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
seq_len_ = scope->FindVar((arg_->inlinks[0]).external)
->GetMutable<Tensor>() ->GetMutable<Tensor>()
->dims()[0]; ->dims()[0];
auto step_scopes = GetStepScopes(scope); auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
PADDLE_ENFORCE(scope->FindVar(arg_->step_net), "step net is not in scope."); PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr,
Variable* net = scope->FindVar(arg_->step_net); "step net is not in scope.");
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net"); PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) { if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
} }
net->GetMutable<NetOp>()->InferShape(step_scopes[step_id]); net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]);
} }
auto outlinks = arg_->outlinks; auto outlinks = arg_->outlinks;
......
...@@ -70,18 +70,18 @@ struct ArgumentName { ...@@ -70,18 +70,18 @@ struct ArgumentName {
/** /**
* Prepare inputs for each step net. * Prepare inputs for each step net.
*/ */
void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes, void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& inlinks, const std::vector<Link>& inlinks,
const size_t seq_len); const size_t seq_len);
/** /**
* Process outputs of step nets and merge to variables. * Process outputs of step nets and merge to variables.
*/ */
void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes, void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& outlinks, const std::vector<Link>& outlinks,
const size_t seq_len); const size_t seq_len);
void LinkMemories(std::vector<std::shared_ptr<Scope>>& step_scopes, void LinkMemories(const std::vector<Scope*>& step_scopes,
const std::vector<MemoryAttr>& memories, const std::vector<MemoryAttr>& memories,
size_t step_id, size_t step_id,
int offset); int offset);
...@@ -100,15 +100,14 @@ void InitArgument(const ArgumentName& name, Argument* arg); ...@@ -100,15 +100,14 @@ void InitArgument(const ArgumentName& name, Argument* arg);
class RecurrentAlgorithm { class RecurrentAlgorithm {
public: public:
void Run(const std::shared_ptr<Scope>& scope, void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const;
const platform::DeviceContext& dev_ctx) const;
void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); } void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
/** /**
* InferShape must be called before Run. * InferShape must be called before Run.
*/ */
void InferShape(const std::shared_ptr<Scope>& scope) const; void InferShape(const Scope& scope) const;
protected: protected:
/* /*
...@@ -117,15 +116,13 @@ protected: ...@@ -117,15 +116,13 @@ protected:
* NOTE the scopes are reused in both the forward and backward, so just * NOTE the scopes are reused in both the forward and backward, so just
* create once and expand its size if more steps need. * create once and expand its size if more steps need.
*/ */
void CreateScopes(std::shared_ptr<Scope> scope) const; void CreateScopes(const Scope& scope) const;
inline const std::vector<std::shared_ptr<Scope>>& GetStepScopes( const std::vector<Scope*>& GetStepScopes(const Scope& scope) const {
std::shared_ptr<Scope> scope) const { return *scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();
return *(scope->FindVar(arg_->step_scopes))
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
} }
void InitMemories(std::shared_ptr<Scope> step_scopes) const; void InitMemories(Scope* step_scopes) const;
private: private:
std::unique_ptr<rnn::Argument> arg_; std::unique_ptr<rnn::Argument> arg_;
...@@ -146,21 +143,18 @@ class RecurrentGradientAlgorithm { ...@@ -146,21 +143,18 @@ class RecurrentGradientAlgorithm {
public: public:
void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); } void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
void Run(const std::shared_ptr<Scope>& scope, void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const;
const platform::DeviceContext& dev_ctx) const;
void LinkBootMemoryGradients(std::shared_ptr<Scope> step_scopes) const; void LinkBootMemoryGradients(Scope* step_scopes) const;
/** /**
* InferShape must be called before Run. * InferShape must be called before Run.
*/ */
void InferShape(const std::shared_ptr<Scope>& scope) const; void InferShape(const Scope& scope) const;
protected: protected:
inline const std::vector<std::shared_ptr<Scope>>& GetStepScopes( inline const std::vector<Scope*>& GetStepScopes(const Scope& scope) const {
std::shared_ptr<Scope> scope) const { return *scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();
return *(scope->FindVar(arg_->step_scopes))
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
} }
private: private:
...@@ -175,11 +169,11 @@ public: ...@@ -175,11 +169,11 @@ public:
/** /**
* InferShape must be called before Run. * InferShape must be called before Run.
*/ */
virtual void InferShape(const std::shared_ptr<Scope>& scope) const override { virtual void InferShape(const Scope& scope) const override {
alg_.InferShape(scope); alg_.InferShape(scope);
} }
virtual void Run(const std::shared_ptr<Scope>& scope, virtual void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
alg_.Run(scope, dev_ctx); alg_.Run(scope, dev_ctx);
} }
...@@ -197,11 +191,11 @@ public: ...@@ -197,11 +191,11 @@ public:
/** /**
* InferShape must be called before Run. * InferShape must be called before Run.
*/ */
virtual void InferShape(const std::shared_ptr<Scope>& scope) const override { virtual void InferShape(const Scope& scope) const override {
alg_.InferShape(scope); alg_.InferShape(scope);
} }
virtual void Run(const std::shared_ptr<Scope>& scope, virtual void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
alg_.Run(scope, dev_ctx); alg_.Run(scope, dev_ctx);
} }
......
...@@ -34,41 +34,40 @@ protected: ...@@ -34,41 +34,40 @@ protected:
virtual void TearDown() override {} virtual void TearDown() override {}
void CreateGlobalVariables() { void CreateGlobalVariables() {
scope_ = std::make_shared<Scope>();
// create input, and init content // create input, and init content
LOG(INFO) << "create global variable x"; LOG(INFO) << "create global variable x";
for (auto inlink : std::vector<std::string>{"x", "x0", "x1", "h"}) { for (auto inlink : std::vector<std::string>{"x", "x0", "x1", "h"}) {
Variable* x = scope_->NewVar(inlink); Variable* x = scope_.NewVar(inlink);
DDim dims = make_ddim(std::vector<int>{ DDim dims = make_ddim(std::vector<int>{
10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/}); 10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/});
x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace()); x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
} }
// create output alias just for test // create output alias just for test
for (auto inlink : std::vector<std::string>{"h@alias"}) { for (auto inlink : std::vector<std::string>{"h@alias"}) {
Variable* x = scope_->NewVar(inlink); Variable* x = scope_.NewVar(inlink);
DDim dims = DDim dims =
make_ddim(std::vector<int>{20 /*batch size*/, 30 /*input dim*/}); make_ddim(std::vector<int>{20 /*batch size*/, 30 /*input dim*/});
x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace()); x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
} }
LOG(INFO) << "create global variable w"; LOG(INFO) << "create global variable w";
Variable* w = scope_->NewVar("rnn/w"); Variable* w = scope_.NewVar("rnn/w");
w->GetMutable<Tensor>()->mutable_data<float>( w->GetMutable<Tensor>()->mutable_data<float>(
make_ddim(std::vector<int>{30, 30}), platform::CPUPlace()); make_ddim(std::vector<int>{30, 30}), platform::CPUPlace());
for (auto boot : std::vector<std::string>{"x_boot", "h_boot"}) { for (auto boot : std::vector<std::string>{"x_boot", "h_boot"}) {
LOG(INFO) << "create global variable " << boot; LOG(INFO) << "create global variable " << boot;
Variable* h_boot = scope_->NewVar(boot); Variable* h_boot = scope_.NewVar(boot);
h_boot->GetMutable<Tensor>()->mutable_data<float>( h_boot->GetMutable<Tensor>()->mutable_data<float>(
make_ddim(std::vector<int>{20 /*batch size*/, 30 /*input dim*/}), make_ddim(std::vector<int>{20 /*batch size*/, 30 /*input dim*/}),
platform::CPUPlace()); platform::CPUPlace());
} }
LOG(INFO) << "create variable step_scopes"; LOG(INFO) << "create variable step_scopes";
scope_->NewVar("step_scopes"); scope_.NewVar("step_scopes");
LOG(INFO) << "create variable h"; LOG(INFO) << "create variable h";
scope_->NewVar("h"); scope_.NewVar("h");
} }
void CreateRNNOp() { void CreateRNNOp() {
...@@ -150,7 +149,7 @@ protected: ...@@ -150,7 +149,7 @@ protected:
void CreateStepNet() { void CreateStepNet() {
LOG(INFO) << "create variable step_net"; LOG(INFO) << "create variable step_net";
Variable* var = scope_->NewVar("step_net"); Variable* var = scope_.NewVar("step_net");
auto net = var->GetMutable<NetOp>(); auto net = var->GetMutable<NetOp>();
// rnn/s is net's input or output? // rnn/s is net's input or output?
net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"}; net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"};
...@@ -164,7 +163,7 @@ protected: ...@@ -164,7 +163,7 @@ protected:
} }
// father scope // father scope
std::shared_ptr<Scope> scope_; Scope scope_;
std::shared_ptr<OperatorBase> rnn_op_; std::shared_ptr<OperatorBase> rnn_op_;
}; };
...@@ -191,66 +190,64 @@ protected: ...@@ -191,66 +190,64 @@ protected:
virtual void TearDown() override {} virtual void TearDown() override {}
void CreateGlobalVariables() { void CreateGlobalVariables() {
scope_ = std::make_shared<Scope>();
// inputs: x // inputs: x
LOG(INFO) << "create global variable x"; LOG(INFO) << "create global variable x";
Variable* x = scope_->NewVar("x"); Variable* x = scope_.NewVar("x");
DDim dims = DDim dims =
make_ddim({10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/}); make_ddim({10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/});
x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace()); x->GetMutable<Tensor>()->mutable_data<float>(dims, platform::CPUPlace());
// inputs: h_boot // inputs: h_boot
LOG(INFO) << "create global variable h_boot"; LOG(INFO) << "create global variable h_boot";
Variable* h_boot = scope_->NewVar("h_boot"); Variable* h_boot = scope_.NewVar("h_boot");
h_boot->GetMutable<Tensor>()->mutable_data<float>( h_boot->GetMutable<Tensor>()->mutable_data<float>(
make_ddim({20 /*batch size*/, 30 /*input dim*/}), platform::CPUPlace()); make_ddim({20 /*batch size*/, 30 /*input dim*/}), platform::CPUPlace());
// inputs: w // inputs: w
LOG(INFO) << "create global variable w"; LOG(INFO) << "create global variable w";
Variable* w = scope_->NewVar("rnn/w"); Variable* w = scope_.NewVar("rnn/w");
w->GetMutable<Tensor>()->mutable_data<float>(make_ddim({30, 30}), w->GetMutable<Tensor>()->mutable_data<float>(make_ddim({30, 30}),
platform::CPUPlace()); platform::CPUPlace());
// inputs: h_grad // inputs: h_grad
LOG(INFO) << "create variable h_grad"; LOG(INFO) << "create variable h_grad";
Variable* dh = scope_->NewVar("h_grad"); Variable* dh = scope_.NewVar("h_grad");
dh->GetMutable<Tensor>()->mutable_data<float>(make_ddim({10, 20, 30}), dh->GetMutable<Tensor>()->mutable_data<float>(make_ddim({10, 20, 30}),
platform::CPUPlace()); platform::CPUPlace());
// inputs: step_scopes // inputs: step_scopes
LOG(INFO) << "create variable step_scopes"; LOG(INFO) << "create variable step_scopes";
scope_->NewVar("step_scopes"); scope_.NewVar("step_scopes");
// inputs: step_net // inputs: step_net
LOG(INFO) << "create variable step_net"; LOG(INFO) << "create variable step_net";
scope_->NewVar("step_net"); scope_.NewVar("step_net");
// outputs: w_grad // outputs: w_grad
LOG(INFO) << "create global variable w_grad"; LOG(INFO) << "create global variable w_grad";
scope_->NewVar("rnn/w_grad"); scope_.NewVar("rnn/w_grad");
// outputs: x_grad // outputs: x_grad
LOG(INFO) << "create global variable x_grad"; LOG(INFO) << "create global variable x_grad";
scope_->NewVar("x_grad"); scope_.NewVar("x_grad");
// outputs: h_boot_grad // outputs: h_boot_grad
LOG(INFO) << "create global variable h_boot_grad"; LOG(INFO) << "create global variable h_boot_grad";
scope_->NewVar("h_boot_grad"); scope_.NewVar("h_boot_grad");
} }
void CreateStepScopes() { void CreateStepScopes() {
std::vector<std::shared_ptr<Scope>>* step_scopes = auto step_scopes =
scope_->FindVar("step_scopes") scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
auto scope = std::make_shared<Scope>(scope_); auto& scope = scope_.NewScope();
auto pre_t = scope->NewVar("rnn/pre_h")->GetMutable<Tensor>(); auto pre_t = scope.NewVar("rnn/pre_h")->GetMutable<Tensor>();
pre_t->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace()); pre_t->mutable_data<float>({20, 30}, platform::CPUPlace());
auto tensor = scope->NewVar("rnn/h")->GetMutable<Tensor>(); auto tensor = scope.NewVar("rnn/h")->GetMutable<Tensor>();
tensor->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace()); tensor->mutable_data<float>({20, 30}, platform::CPUPlace());
// for unit test of ConcatOutputs // for unit test of ConcatOutputs
auto xg = scope->NewVar("rnn/x_grad")->GetMutable<Tensor>(); auto xg = scope.NewVar("rnn/x_grad")->GetMutable<Tensor>();
xg->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace()); xg->mutable_data<float>({20, 30}, platform::CPUPlace());
step_scopes->push_back(scope); step_scopes->emplace_back(&scope);
} }
// last time step // last time step
auto g = (*step_scopes)[9]->NewVar("rnn/h_pre_grad")->GetMutable<Tensor>(); auto g = (*step_scopes)[9]->NewVar("rnn/h_pre_grad")->GetMutable<Tensor>();
g->mutable_data<float>(make_ddim({20, 30}), platform::CPUPlace()); g->mutable_data<float>({20, 30}, platform::CPUPlace());
} }
void CreateRNNGradientAlgorithm() { void CreateRNNGradientAlgorithm() {
...@@ -278,7 +275,7 @@ protected: ...@@ -278,7 +275,7 @@ protected:
void CreateStepNet() { void CreateStepNet() {
LOG(INFO) << "create variable step_net"; LOG(INFO) << "create variable step_net";
Variable* var = scope_->NewVar("step_net"); Variable* var = scope_.NewVar("step_net");
auto net = var->GetMutable<NetOp>(); auto net = var->GetMutable<NetOp>();
net->AddOp(OpRegistry::CreateOp("mul", net->AddOp(OpRegistry::CreateOp("mul",
{"rnn/h_pre", "rnn/w", "rnn/s_grad"}, {"rnn/h_pre", "rnn/w", "rnn/s_grad"},
...@@ -298,9 +295,8 @@ protected: ...@@ -298,9 +295,8 @@ protected:
rnn::Link inlink; rnn::Link inlink;
inlink.external = "x"; inlink.external = "x";
inlink.internal = "rnn/x"; inlink.internal = "rnn/x";
std::vector<std::shared_ptr<Scope>>* step_scopes = auto step_scopes =
scope_->FindVar("step_scopes") scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10); rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10);
} }
...@@ -312,15 +308,14 @@ protected: ...@@ -312,15 +308,14 @@ protected:
mem_attr.boot_var = "boot_h"; mem_attr.boot_var = "boot_h";
std::vector<rnn::MemoryAttr> memories; std::vector<rnn::MemoryAttr> memories;
memories.push_back(mem_attr); memories.push_back(mem_attr);
std::vector<std::shared_ptr<Scope>>* step_scopes = auto step_scopes =
scope_->FindVar("step_scopes") scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
->GetMutable<std::vector<std::shared_ptr<Scope>>>();
for (int i = 1; i < 10; ++i) { for (int i = 1; i < 10; ++i) {
rnn::LinkMemories(*step_scopes, memories, i, -1); rnn::LinkMemories(*step_scopes, memories, i, -1);
} }
} }
std::shared_ptr<Scope> scope_; Scope scope_;
RecurrentGradientAlgorithm rnn_grad_algo_; RecurrentGradientAlgorithm rnn_grad_algo_;
}; };
...@@ -339,14 +334,14 @@ TEST(RecurrentOp, LinkMemories) { ...@@ -339,14 +334,14 @@ TEST(RecurrentOp, LinkMemories) {
// create and init step scopes // create and init step scopes
int len = 10; int len = 10;
std::vector<std::shared_ptr<Scope>> step_scopes; std::vector<Scope*> step_scopes;
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
auto scope = std::make_shared<Scope>(); auto scope = new Scope();
scope->NewVar("pre_h"); scope->NewVar("pre_h");
auto tensor = scope->NewVar("h")->GetMutable<Tensor>(); auto tensor = scope->NewVar("h")->GetMutable<Tensor>();
float* data = tensor->mutable_data<float>(make_ddim({15, 20}), CPUPlace()); float* data = tensor->mutable_data<float>({15, 20}, CPUPlace());
for (int i = 0; i < 15 * 20; ++i) { for (int j = 0; j < 15 * 20; ++j) {
data[i] = rand() * (1. / (double)RAND_MAX); data[j] = rand() * (1. / (double)RAND_MAX);
} }
step_scopes.push_back(scope); step_scopes.push_back(scope);
} }
...@@ -388,7 +383,17 @@ TEST(RecurrentOp, LinkMemories) { ...@@ -388,7 +383,17 @@ TEST(RecurrentOp, LinkMemories) {
ASSERT_FLOAT_EQ(a[i], b[i]); ASSERT_FLOAT_EQ(a[i], b[i]);
} }
} }
for (auto s : step_scopes) {
delete s;
}
} }
USE_OP(add_two); USE_OP(add_two);
USE_OP(mul); USE_OP(mul);
// int main() {
// //! TODO(yuyang18): Temporary disable this unit-test because implementation
// //! error.
// return 0;
//}
\ No newline at end of file
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <glog/logging.h>
#include <paddle/string/printf.h> #include <paddle/string/printf.h>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
...@@ -127,17 +128,6 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( ...@@ -127,17 +128,6 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
#endif // PADDLE_ONLY_CPU #endif // PADDLE_ONLY_CPU
template <typename T, typename... Args>
inline typename std::enable_if<std::is_pointer<T>::value, void>::type
throw_on_error(T stat, const Args&... args) {
if (stat == nullptr) {
return;
} else {
throw std::runtime_error("Pointer value is nullptr: " +
string::Sprintf(args...));
}
}
template <typename T> template <typename T>
inline void throw_on_error(T e) { inline void throw_on_error(T e) {
throw_on_error(e, ""); throw_on_error(e, "");
......
...@@ -102,11 +102,18 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -102,11 +102,18 @@ All parameter, weight, gradient are variables in Paddle.
}, },
py::return_value_policy::reference); py::return_value_policy::reference);
py::class_<pd::Scope, std::shared_ptr<pd::Scope>>(m, "Scope") py::class_<pd::Scope>(m, "Scope", "")
.def(py::init<const std::shared_ptr<pd::Scope>&>()) .def("new_var",
.def("get_var", &pd::Scope::FindVar, py::return_value_policy::reference) [](pd::Scope& self, const std::string& name) -> pd::Variable* {
.def("create_var", &pd::Scope::NewVar, py::return_value_policy::reference) return self.NewVar(name);
.def("get_var_name", &pd::Scope::FindVarName); },
py::return_value_policy::reference)
.def("find_var", &pd::Scope::FindVar, py::return_value_policy::reference)
.def(py::init<>())
.def("new_scope",
[](pd::Scope& self) -> pd::Scope* { return &self.NewScope(); },
py::return_value_policy::reference)
.def("drop_kids", &pd::Scope::DropKids);
//! @note: Be careful! PyBind will return std::string as an unicode, not //! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python. //! Python str. If you want a str object, you should cast them in Python.
......
...@@ -5,7 +5,7 @@ Default scope function. ...@@ -5,7 +5,7 @@ Default scope function.
thread-local stack of Scope. Top of that stack is current scope, the bottom thread-local stack of Scope. Top of that stack is current scope, the bottom
of that stack is all scopes' parent. of that stack is all scopes' parent.
Invoking `create_var/get_var` can `create/get` variable in current scope. Invoking `new_var/find_var` can `new/find` variable in current scope.
Invoking `enter_local_scope/leave_local_scope` can create or destroy local Invoking `enter_local_scope/leave_local_scope` can create or destroy local
scope. scope.
...@@ -19,8 +19,8 @@ import threading ...@@ -19,8 +19,8 @@ import threading
__tl_scope__ = threading.local() __tl_scope__ = threading.local()
__all__ = [ __all__ = [
'get_cur_scope', 'enter_local_scope', 'leave_local_scope', 'create_var', 'get_cur_scope', 'enter_local_scope', 'leave_local_scope', 'new_var',
'get_var', 'scoped_function' 'find_var', 'scoped_function'
] ]
...@@ -33,7 +33,7 @@ def get_cur_scope(): ...@@ -33,7 +33,7 @@ def get_cur_scope():
if cur_scope_stack is None: if cur_scope_stack is None:
__tl_scope__.cur_scope = list() __tl_scope__.cur_scope = list()
if len(__tl_scope__.cur_scope) == 0: if len(__tl_scope__.cur_scope) == 0:
__tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope(None)) __tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope())
return __tl_scope__.cur_scope[-1] return __tl_scope__.cur_scope[-1]
...@@ -42,7 +42,7 @@ def enter_local_scope(): ...@@ -42,7 +42,7 @@ def enter_local_scope():
Enter a new local scope Enter a new local scope
""" """
cur_scope = get_cur_scope() cur_scope = get_cur_scope()
new_scope = paddle.v2.framework.core.Scope(cur_scope) new_scope = cur_scope.new_scope()
__tl_scope__.cur_scope.append(new_scope) __tl_scope__.cur_scope.append(new_scope)
...@@ -51,20 +51,21 @@ def leave_local_scope(): ...@@ -51,20 +51,21 @@ def leave_local_scope():
Leave local scope Leave local scope
""" """
__tl_scope__.cur_scope.pop() __tl_scope__.cur_scope.pop()
get_cur_scope().drop_kids()
def create_var(name): def new_var(name):
""" """
create variable in current scope. create variable in current scope.
""" """
return get_cur_scope().create_var(name) return get_cur_scope().new_var(name)
def get_var(name): def find_var(name):
""" """
get variable in current scope. get variable in current scope.
""" """
return get_cur_scope().get_var(name) return get_cur_scope().find_var(name)
def scoped_function(func): def scoped_function(func):
......
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations from paddle.v2.framework.create_op_creation_methods import op_creations
from default_scope_funcs import create_var, get_var, get_cur_scope from default_scope_funcs import new_var, find_var, get_cur_scope
__all__ = ['Network'] # Only expose Network __all__ = ['Network'] # Only expose Network
...@@ -29,12 +29,15 @@ class NetworkFunctor(object): ...@@ -29,12 +29,15 @@ class NetworkFunctor(object):
if ipt in kwargs: if ipt in kwargs:
var = kwargs[ipt] var = kwargs[ipt]
if isinstance(var, basestring): if isinstance(var, basestring):
var = create_var(var) tmp = new_var(var)
self.net.var_names[tmp] = var
var = tmp
if not isinstance(var, core.Variable): if not isinstance(var, core.Variable):
raise TypeError( raise TypeError(
"Input of op creation must be string or variable") "Input of op creation must be string or variable")
kwargs[ipt] = get_cur_scope().get_var_name(var) kwargs[ipt] = self.net.var_names[var]
notemp_outputs = self.func.all_not_temp_output_args notemp_outputs = self.func.all_not_temp_output_args
...@@ -49,17 +52,20 @@ class NetworkFunctor(object): ...@@ -49,17 +52,20 @@ class NetworkFunctor(object):
if opt in kwargs: if opt in kwargs:
var = kwargs[opt] var = kwargs[opt]
if isinstance(var, basestring): if isinstance(var, basestring):
var = create_var(var) tmp = new_var(var)
self.net.var_names[tmp] = var
var = tmp
if not isinstance(var, core.Variable): if not isinstance(var, core.Variable):
raise TypeError( raise TypeError(
"Output of op creation must be string or variable") "Output of op creation must be string or variable")
kwargs[opt] = get_cur_scope().get_var_name(var) kwargs[opt] = self.net.var_names[var]
op = self.func(**kwargs) op = self.func(**kwargs)
self.net.net.add_op(op) self.net.net.add_op(op)
lst = [get_var(kwargs[opt]) for opt in notemp_outputs] lst = [find_var(kwargs[opt]) for opt in notemp_outputs]
if len(lst) == 1: if len(lst) == 1:
return lst[0] return lst[0]
elif len(lst) == 0: elif len(lst) == 0:
...@@ -89,6 +95,7 @@ class Network(object): ...@@ -89,6 +95,7 @@ class Network(object):
self.net = core.Net.create() self.net = core.Net.create()
funcs = (func_name for func_name in dir(op_creations) funcs = (func_name for func_name in dir(op_creations)
if not func_name.startswith("__")) if not func_name.startswith("__"))
self.var_names = dict()
# TODO(yuyang18): This code can work, but do not generate a good # TODO(yuyang18): This code can work, but do not generate a good
# docstring, try to give a better way generate function in runtime # docstring, try to give a better way generate function in runtime
......
...@@ -24,13 +24,13 @@ class OpTestMeta(type): ...@@ -24,13 +24,13 @@ class OpTestMeta(type):
func = getattr(creation.op_creations, self.type, None) func = getattr(creation.op_creations, self.type, None)
self.assertIsNotNone(func) self.assertIsNotNone(func)
scope = core.Scope(None) scope = core.Scope()
kwargs = dict() kwargs = dict()
for in_name in func.all_input_args: for in_name in func.all_input_args:
if hasattr(self, in_name): if hasattr(self, in_name):
kwargs[in_name] = in_name kwargs[in_name] = in_name
var = scope.create_var(in_name).get_tensor() var = scope.new_var(in_name).get_tensor()
arr = getattr(self, in_name) arr = getattr(self, in_name)
var.set_dims(arr.shape) var.set_dims(arr.shape)
var.set(arr) var.set(arr)
...@@ -40,7 +40,7 @@ class OpTestMeta(type): ...@@ -40,7 +40,7 @@ class OpTestMeta(type):
for out_name in func.all_output_args: for out_name in func.all_output_args:
if hasattr(self, out_name): if hasattr(self, out_name):
kwargs[out_name] = out_name kwargs[out_name] = out_name
scope.create_var(out_name).get_tensor() scope.new_var(out_name).get_tensor()
for attr_name in func.all_attr_args: for attr_name in func.all_attr_args:
if hasattr(self, attr_name): if hasattr(self, attr_name):
...@@ -54,7 +54,7 @@ class OpTestMeta(type): ...@@ -54,7 +54,7 @@ class OpTestMeta(type):
op.run(scope, ctx) op.run(scope, ctx)
for out_name in func.all_output_args: for out_name in func.all_output_args:
actual = numpy.array(scope.get_var(out_name).get_tensor()) actual = numpy.array(scope.find_var(out_name).get_tensor())
expect = getattr(self, out_name) expect = getattr(self, out_name)
# TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul # TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
# has some diff, and could not pass unittest. So I set decimal 3 here. # has some diff, and could not pass unittest. So I set decimal 3 here.
......
...@@ -7,19 +7,19 @@ class TestDefaultScopeFuncs(unittest.TestCase): ...@@ -7,19 +7,19 @@ class TestDefaultScopeFuncs(unittest.TestCase):
self.assertIsNotNone(get_cur_scope()) self.assertIsNotNone(get_cur_scope())
def test_none_variable(self): def test_none_variable(self):
self.assertIsNone(get_var("test")) self.assertIsNone(find_var("test"))
def test_create_var_get_var(self): def test_create_var_get_var(self):
var_a = create_var("var_a") var_a = new_var("var_a")
self.assertIsNotNone(var_a) self.assertIsNotNone(var_a)
self.assertIsNotNone(get_cur_scope().get_var('var_a')) self.assertIsNotNone(get_cur_scope().find_var('var_a'))
enter_local_scope() enter_local_scope()
self.assertIsNotNone(get_cur_scope().get_var('var_a')) self.assertIsNotNone(get_cur_scope().find_var('var_a'))
leave_local_scope() leave_local_scope()
def test_var_get_int(self): def test_var_get_int(self):
def __new_scope__(): def __new_scope__():
i = create_var("var_i") i = new_var("var_i")
self.assertFalse(i.is_int()) self.assertFalse(i.is_int())
i.set_int(10) i.set_int(10)
self.assertTrue(i.is_int()) self.assertTrue(i.is_int())
......
...@@ -6,13 +6,13 @@ import paddle.v2.framework.create_op_creation_methods as creation ...@@ -6,13 +6,13 @@ import paddle.v2.framework.create_op_creation_methods as creation
class TestFc(unittest.TestCase): class TestFc(unittest.TestCase):
def test_fc(self): def test_fc(self):
scope = core.Scope(None) scope = core.Scope()
x = scope.create_var("X") x = scope.new_var("X")
x_tensor = x.get_tensor() x_tensor = x.get_tensor()
x_tensor.set_dims([1000, 784]) x_tensor.set_dims([1000, 784])
x_tensor.alloc_float() x_tensor.alloc_float()
w = scope.create_var("W") w = scope.new_var("W")
w_tensor = w.get_tensor() w_tensor = w.get_tensor()
w_tensor.set_dims([784, 100]) w_tensor.set_dims([784, 100])
w_tensor.alloc_float() w_tensor.alloc_float()
...@@ -25,10 +25,10 @@ class TestFc(unittest.TestCase): ...@@ -25,10 +25,10 @@ class TestFc(unittest.TestCase):
op = creation.op_creations.fc(X="X", Y="Y", W="W") op = creation.op_creations.fc(X="X", Y="Y", W="W")
for out in op.outputs(): for out in op.outputs():
if scope.get_var(out) is None: if scope.find_var(out) is None:
scope.create_var(out).get_tensor() scope.new_var(out).get_tensor()
tensor = scope.get_var("Y").get_tensor() tensor = scope.find_var("Y").get_tensor()
op.infer_shape(scope) op.infer_shape(scope)
self.assertEqual([1000, 100], tensor.shape()) self.assertEqual([1000, 100], tensor.shape())
......
...@@ -5,29 +5,29 @@ import unittest ...@@ -5,29 +5,29 @@ import unittest
class TestScope(unittest.TestCase): class TestScope(unittest.TestCase):
def test_create_destroy(self): def test_create_destroy(self):
paddle_c = paddle.v2.framework.core paddle_c = paddle.v2.framework.core
scope = paddle_c.Scope(None) scope = paddle_c.Scope()
self.assertIsNotNone(scope) self.assertIsNotNone(scope)
scope_with_parent = paddle_c.Scope(scope) scope_with_parent = scope.new_scope()
self.assertIsNotNone(scope_with_parent) self.assertIsNotNone(scope_with_parent)
def test_none_variable(self): def test_none_variable(self):
paddle_c = paddle.v2.framework.core paddle_c = paddle.v2.framework.core
scope = paddle_c.Scope(None) scope = paddle_c.Scope()
self.assertIsNone(scope.get_var("test")) self.assertIsNone(scope.find_var("test"))
def test_create_var_get_var(self): def test_create_var_get_var(self):
paddle_c = paddle.v2.framework.core paddle_c = paddle.v2.framework.core
scope = paddle_c.Scope(None) scope = paddle_c.Scope()
var_a = scope.create_var("var_a") var_a = scope.new_var("var_a")
self.assertIsNotNone(var_a) self.assertIsNotNone(var_a)
self.assertIsNotNone(scope.get_var('var_a')) self.assertIsNotNone(scope.find_var('var_a'))
scope2 = paddle_c.Scope(scope) scope2 = scope.new_scope()
self.assertIsNotNone(scope2.get_var('var_a')) self.assertIsNotNone(scope2.find_var('var_a'))
def test_var_get_int(self): def test_var_get_int(self):
paddle_c = paddle.v2.framework.core paddle_c = paddle.v2.framework.core
scope = paddle_c.Scope(None) scope = paddle_c.Scope()
var = scope.create_var("test_int") var = scope.new_var("test_int")
var.set_int(10) var.set_int(10)
self.assertTrue(var.is_int()) self.assertTrue(var.is_int())
self.assertEqual(10, var.get_int()) self.assertEqual(10, var.get_int())
......
...@@ -5,8 +5,8 @@ import numpy ...@@ -5,8 +5,8 @@ import numpy
class TestScope(unittest.TestCase): class TestScope(unittest.TestCase):
def test_int_tensor(self): def test_int_tensor(self):
scope = core.Scope(None) scope = core.Scope()
var = scope.create_var("test_tensor") var = scope.new_var("test_tensor")
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set_dims([1000, 784]) tensor.set_dims([1000, 784])
...@@ -23,8 +23,8 @@ class TestScope(unittest.TestCase): ...@@ -23,8 +23,8 @@ class TestScope(unittest.TestCase):
self.assertEqual(2.0, tensor_array_2[19, 11]) self.assertEqual(2.0, tensor_array_2[19, 11])
def test_float_tensor(self): def test_float_tensor(self):
scope = core.Scope(None) scope = core.Scope()
var = scope.create_var("test_tensor") var = scope.new_var("test_tensor")
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set_dims([1000, 784]) tensor.set_dims([1000, 784])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册