Commit 8c653ba7 authored by Yu Yang

Complete remove std::shared_ptr

Parent c7f25325
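The diff below replaces `std::shared_ptr<OperatorBase>` with `std::unique_ptr<OperatorBase>` in the operator factories, `NetOp`, and `RecurrentOp`, so every operator has exactly one owner and ownership transfers are spelled out with `std::move`. A minimal standalone sketch of that ownership pattern (the `Op`/`Container` names are illustrative, not Paddle types):

#include <memory>
#include <utility>
#include <vector>

struct Op {
  virtual ~Op() = default;
};

struct Container {
  // The container is the single owner; callers hand ownership over explicitly.
  void Add(std::unique_ptr<Op>&& op) { ops_.push_back(std::move(op)); }
  std::vector<std::unique_ptr<Op>> ops_;
};

// Factory returns a unique_ptr instead of a shared_ptr.
std::unique_ptr<Op> MakeOp() { return std::unique_ptr<Op>(new Op); }

int main() {
  Container net;
  auto op = MakeOp();
  net.Add(std::move(op));  // after this, `op` is empty
  return net.ops_.size() == 1 ? 0 : 1;
}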
@@ -15,6 +15,8 @@
 #include "paddle/framework/backward.h"
 #include <list>
+#include <memory>
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/net_op.h"
 #include "paddle/operators/recurrent_op.h"
@@ -43,11 +45,11 @@ static bool AllInSet(
   return all_in_set;
 }
-static std::shared_ptr<OperatorBase> NOP() {
-  auto net_op = std::make_shared<operators::NetOp>();
+static std::unique_ptr<OperatorBase> NOP() {
+  auto net_op = new operators::NetOp();
   net_op->SetType("@NOP@");
   net_op->CompleteAddOp();
-  return net_op;
+  return std::unique_ptr<OperatorBase>(net_op);
 }
 // Get backward operator from a forward operator, a recursive implementation.
@@ -62,11 +64,7 @@ static std::shared_ptr<OperatorBase> NOP() {
 // operator, in a complex situation, it maybe a NetOp.
 //
 // See Backward.h for details
-static std::shared_ptr<OperatorBase> BackwardRecursive(
-    const OperatorBase& forwardOp,
-    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
-std::shared_ptr<OperatorBase> BackwardRecursive(
+static std::unique_ptr<OperatorBase> BackwardRecursive(
     const OperatorBase& forwardOp,
     std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
   // If all input gradients of forwarding operator do not need to calculate,
@@ -91,7 +89,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
   }
   // Returned gradient network
-  auto net = std::make_shared<operators::NetOp>();
+  auto net = std::unique_ptr<operators::NetOp>();
   if (forwardOp.IsNetOp()) {
     // Because forwardOp is a net op, it can static_cast.
@@ -105,14 +103,14 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     // reversely travel forwardNet and collect all duplicate outputs.
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
          ++it, ++local_op_id) {
-      auto fwd = *it;
+      auto& fwd = *it;
       auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
-      net->AddOp(bwd);
       ForEachVarName(bwd->Outputs(),
                      [&dup_output_ops, local_op_id](const std::string& out) {
                        dup_output_ops[out].emplace_back(local_op_id);
                        return false;
                      });
+      net->AddOp(std::move(bwd));
     }
     // Get unique ID for this method.
     auto uid = uniq_id++;
@@ -122,7 +120,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     // to handle this case. For each duplicate output, rename it to an alias
     // (original name with a offset), append an `add` op for its operator,
     // and finally sum all the alias variable to the final output variable y.
-    using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
+    using Pos = std::pair<size_t, std::unique_ptr<OperatorBase>>;
     std::list<Pos> insert_position;
     for (auto& dup_output_op : dup_output_ops) {
       const std::string& name = dup_output_op.first;
@@ -150,13 +148,13 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
         [](const Pos& l, const Pos& r) { return l.first > r.first; });
     for (auto& pos : insert_position) {
-      net->InsertOp(pos.first + 1, pos.second);
+      net->InsertOp(pos.first + 1, std::move(pos.second));
     }
   } else {
-    std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
-    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net,
-                                       grad_op](const std::string& grad_input) {
+    std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
+    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
+                                          const std::string& grad_input) {
      if (no_grad_names.count(grad_input)) {
        // +1 for \0
        std::string prefix = grad_input.substr(
@@ -190,20 +188,20 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
       const auto& stepnet_op =
           *static_cast<const OperatorBase*>(&rnnop.stepnet());
       // create stepnet's gradient op
-      auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id);
       rnn_grad_op->set_stepnet(
-          std::static_pointer_cast<operators::NetOp>(grad_stepnet));
+          BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
     }
     if (net->ops_.empty()) {  // Current no aux op is added to network
       return grad_op;
     }
-    net->AddOp(grad_op);
+    net->AddOp(std::move(grad_op));
   }
   net->SetType("@GENERATED_BACKWARD@");
   net->CompleteAddOp();
-  return net;
-}  // namespace framework
+  return std::unique_ptr<OperatorBase>(
+      static_cast<OperatorBase*>(net.release()));
+}
 // See header for comments
-std::shared_ptr<OperatorBase> Backward(
......
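One detail worth noting in the backward hunk above: with `unique_ptr`, `net->AddOp(std::move(bwd))` has to run after `ForEachVarName(bwd->Outputs(), ...)`, because moving `bwd` into the net leaves it null, which is why the diff relocates that call below the lambda. A small standalone illustration (a hypothetical `Op` type, not Paddle code):

#include <cassert>
#include <memory>
#include <utility>
#include <vector>

struct Op { int num_outputs = 1; };

int main() {
  std::vector<std::unique_ptr<Op>> net;
  auto bwd = std::unique_ptr<Op>(new Op);
  // Read everything needed from `bwd` first ...
  int outputs = bwd->num_outputs;
  // ... then transfer ownership; `bwd` is null afterwards and must not be used.
  net.push_back(std::move(bwd));
  assert(bwd == nullptr);
  return outputs == 1 ? 0 : 1;
}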
@@ -174,7 +174,7 @@ class OpRegistry {
     }
   }
-  static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
+  static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
                                                 const VarNameMap& inputs,
                                                 const VarNameMap& outputs,
                                                 AttributeMap attrs) {
@@ -183,7 +183,7 @@ class OpRegistry {
                    "Operator '%s' has not been registered.", type);
     it->second.checker_->Check(attrs);
     auto op = it->second.creator_(type, inputs, outputs, attrs);
-    return std::shared_ptr<OperatorBase>(op);
+    return std::unique_ptr<OperatorBase>(op);
   }
   static VarNameMap ConvertOpDescVarsToVarNameMap(
@@ -199,7 +199,7 @@ class OpRegistry {
     return ret_val;
   }
-  static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
+  static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
     VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
     VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
     AttributeMap attrs;
@@ -210,11 +210,10 @@ class OpRegistry {
     return CreateOp(op_desc.type(), inputs, outputs, attrs);
   }
-  static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
+  static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
     PADDLE_ENFORCE(!op.IsNetOp(),
                    "Use framework::Backward to get backward ops");
-    std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
-    return grad_op;
+    return std::unique_ptr<OperatorBase>(BuildGradOp(&op));
   }
   static std::unordered_map<std::string, const OpInfo>& op_info_map() {
......
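The registry hunk above keeps the raw-pointer `creator_` callbacks and only wraps their result into a `std::unique_ptr` at the public `CreateOp`/`CreateGradOp` boundary. A simplified sketch of that pattern (the `Registry`/`Base` names are illustrative, not the actual Paddle classes):

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

struct Base {
  virtual ~Base() = default;
};
struct MyOp : Base {};

struct Registry {
  // Creators return raw pointers; ownership is wrapped at the public API.
  std::unordered_map<std::string, std::function<Base*()>> creators;

  std::unique_ptr<Base> Create(const std::string& type) {
    return std::unique_ptr<Base>(creators.at(type)());
  }
};

int main() {
  Registry reg;
  reg.creators["my_op"] = [] { return new MyOp; };
  auto op = reg.Create("my_op");  // the caller owns the op exclusively
  return op ? 0 : 1;
}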
@@ -76,8 +76,7 @@ TEST(OpRegistry, CreateOp) {
   attr->set_type(paddle::framework::AttrType::FLOAT);
   attr->set_f(scale);
-  std::shared_ptr<paddle::framework::OperatorBase> op =
-      paddle::framework::OpRegistry::CreateOp(op_desc);
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   paddle::framework::Scope scope;
   paddle::platform::CPUDeviceContext dev_ctx;
   op->Run(scope, dev_ctx);
@@ -118,8 +117,7 @@ TEST(OpRegistry, DefaultValue) {
   ASSERT_TRUE(op_desc.IsInitialized());
-  std::shared_ptr<paddle::framework::OperatorBase> op =
-      paddle::framework::OpRegistry::CreateOp(op_desc);
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   paddle::framework::Scope scope;
   paddle::platform::CPUDeviceContext dev_ctx;
   op->Run(scope, dev_ctx);
......
@@ -207,8 +207,7 @@ All parameter, weight, gradient are variables in Paddle.
       .def(py::init<>())
       .def("__str__", string::to_string<const platform::CPUPlace &>);
-  py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
-      m, "Operator");
+  py::class_<OperatorBase> operator_base(m, "Operator");
   operator_base.def_static("create", [](py::bytes protobin) {
     OpDesc desc;
@@ -228,25 +227,23 @@ All parameter, weight, gradient are variables in Paddle.
   ExposeOperator(operator_base);
-  py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net(m, "Net");
+  py::class_<operators::NetOp> net(m, "Net");
   net.def_static("create",
-                 []() -> std::shared_ptr<operators::NetOp> {
-                   auto retv = std::make_shared<operators::NetOp>();
+                 []() -> operators::NetOp * {
+                   auto *retv = new operators::NetOp;
                    retv->SetType("plain_net");
                    return retv;
                  })
-      .def("add_op", &operators::NetOp::AddOp)
+      .def("add_op", [](operators::NetOp &self,
+                        const OperatorBase &op) { self.AddOp(op); })
       .def("add_op",
-           [](operators::NetOp &self,
-              const std::shared_ptr<operators::NetOp> &net) -> void {
-             self.AddOp(std::static_pointer_cast<OperatorBase>(net));
+           [](operators::NetOp &self, const operators::NetOp &net) -> void {
+             self.AddOp(net);
           })
      .def("add_op",
           [](operators::NetOp &self,
-              const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
-             self.AddOp(std::static_pointer_cast<OperatorBase>(rnn));
-           })
+              const operators::RecurrentOp &rnn) -> void { self.AddOp(rnn); })
      .def("complete_add_op", &operators::NetOp::CompleteAddOp)
      .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
        self->CompleteAddOp();
@@ -255,12 +252,11 @@ All parameter, weight, gradient are variables in Paddle.
   ExposeOperator(net);
   // recurrent_op
-  py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
-      rnn(m, "RecurrentOp");
+  py::class_<operators::RecurrentOp> rnn(m, "RecurrentOp");
   rnn.def_static(
       "create",
-      [](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
+      [](py::bytes protobin) -> operators::RecurrentOp * {
        OpDesc desc;
        PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                       "Cannot parse user input to OpDesc");
@@ -268,12 +264,11 @@ All parameter, weight, gradient are variables in Paddle.
                       "User OpDesc is not initialized, reason %s",
                       desc.InitializationErrorString());
        auto rnn_op = OpRegistry::CreateOp(desc);
-        return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
+        return static_cast<operators::RecurrentOp *>(rnn_op.release());
      })
-      .def("set_stepnet",
-           [](operators::RecurrentOp &self,
-              const std::shared_ptr<operators::NetOp> &net) -> void {
-             self.set_stepnet(net);
+      .def("set_stepnet", [](operators::RecurrentOp &self,
+                             const operators::NetOp &net) -> void {
+        self.set_stepnet(net.Clone());
      });
   ExposeOperator(rnn);
......
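In the pybind hunk above, the `std::shared_ptr<...>` holder argument is dropped from each `py::class_`, so pybind11 falls back to its default `std::unique_ptr` holder, and the `create` factories return raw pointers whose ownership Python takes over. A hedged sketch of that binding style (a toy `Thing` class, not the Paddle bindings; assumes pybind11 is available):

#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Thing {
  int value = 0;
};

PYBIND11_MODULE(example, m) {
  // Default holder is std::unique_ptr<Thing>; no shared_ptr is involved.
  py::class_<Thing>(m, "Thing")
      .def_static("create",
                  []() -> Thing * {
                    // Returning a raw pointer: pybind11 takes ownership.
                    return new Thing;
                  })
      .def_readwrite("value", &Thing::value);
}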
@@ -45,10 +45,10 @@ class NetOp : public framework::OperatorBase {
       : framework::OperatorBase(
             static_cast<const framework::OperatorBase&>(o)) {
     this->ops_.reserve(o.ops_.size());
-    std::transform(o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_),
-                   [](const std::shared_ptr<OperatorBase>& op)
-                       -> std::shared_ptr<OperatorBase> {
-                     return std::shared_ptr<OperatorBase>(op->Clone());
+    std::transform(
+        o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_),
+        [](const std::unique_ptr<framework::OperatorBase>& op) {
+          return std::unique_ptr<framework::OperatorBase>(op->Clone());
         });
     this->CompleteAddOp();
   }
@@ -86,21 +86,42 @@ class NetOp : public framework::OperatorBase {
     return true;
   }
+  void AddOp(const framework::OperatorBase& op) { AddOp(op.Clone()); }
   /**
    * @brief Add an operator by ptr
    */
-  void AddOp(const std::shared_ptr<OperatorBase>& op) {
+  void AddOp(framework::OperatorBase* op, bool own) {
     PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
     PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op");
-    ops_.push_back(op);
+    if (!own) {
+      op = op->Clone().release();
+    }
+    ops_.emplace_back(op);
   }
-  void InsertOp(size_t pos, const std::shared_ptr<OperatorBase>& op) {
+  void AddOp(std::unique_ptr<framework::OperatorBase>&& op) {
+    AddOp(op.release(), true);
+  }
+  void InsertOp(size_t pos, framework::OperatorBase* op, bool own) {
     PADDLE_ENFORCE(!add_op_done_,
                    "Cannot InsertOp when this network is sealed");
     PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op");
     PADDLE_ENFORCE_LE(pos, ops_.size(), "Out of range");
-    ops_.insert(ops_.begin() + pos, op);
+    if (!own) {
+      op = op->Clone().release();
+    }
+    ops_.insert(ops_.begin() + pos,
+                std::unique_ptr<framework::OperatorBase>(op));
+  }
+  void InsertOp(size_t pos, std::unique_ptr<framework::OperatorBase>&& op) {
+    InsertOp(pos, op.release(), true);
+  }
+  void InsertOp(size_t pos, const framework::OperatorBase& op) {
+    InsertOp(pos, op.Clone());
   }
   void CompleteAddOp(bool calculate = true);
@@ -112,7 +133,7 @@ class NetOp : public framework::OperatorBase {
   std::unique_ptr<framework::OperatorBase> Clone() const override;
-  std::vector<std::shared_ptr<OperatorBase>> ops_;
+  std::vector<std::unique_ptr<framework::OperatorBase>> ops_;
  private:
   bool add_op_done_{false};
......
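The new `NetOp::AddOp`/`InsertOp` overloads in the hunk above either clone an operator passed by const reference or take ownership of one passed as a `std::unique_ptr`. A minimal self-contained sketch of that clone-or-move interface (illustrative `Op`/`Net` types, not the Paddle classes):

#include <memory>
#include <utility>
#include <vector>

struct Op {
  virtual ~Op() = default;
  virtual std::unique_ptr<Op> Clone() const {
    return std::unique_ptr<Op>(new Op(*this));
  }
};

struct Net {
  // Copy-in overload: the net stores its own clone, the caller keeps its op.
  void AddOp(const Op& op) { AddOp(op.Clone()); }
  // Move-in overload: the net takes ownership.
  void AddOp(std::unique_ptr<Op>&& op) { ops_.push_back(std::move(op)); }
  std::vector<std::unique_ptr<Op>> ops_;
};

int main() {
  Net net;
  auto op = std::unique_ptr<Op>(new Op);
  net.AddOp(*op);            // cloned; `op` is still usable
  net.AddOp(std::move(op));  // ownership transferred; `op` is now null
  return net.ops_.size() == 2 ? 0 : 1;
}

The test file below exercises exactly this pair of overloads (`net.AddOp(*op1)` versus `net.InsertOp(2, std::move(op1))`).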
@@ -38,15 +38,12 @@ TEST(OpKernel, all) {
   auto net = std::make_shared<NetOp>();
   ASSERT_NE(net, nullptr);
-  auto op1 = std::shared_ptr<TestOp>(
+  net->AddOp(std::unique_ptr<TestOp>(
       new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
-                 {{"Out", {"y"}}}, {}));
-  net->AddOp(op1);
-  auto op2 = std::shared_ptr<TestOp>(
+                 {{"Out", {"y"}}}, {})));
+  net->AddOp(std::unique_ptr<TestOp>(
       new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}},
-                 {{"Out", {"z"}}}, {}));
-  net->AddOp(op2);
+                 {{"Out", {"z"}}}, {})));
   net->CompleteAddOp();
   AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"},
@@ -61,21 +58,21 @@ TEST(OpKernel, all) {
 TEST(NetOp, insert_op) {
   NetOp net;
-  auto op1 = std::shared_ptr<framework::NOP>(
+  auto op1 = std::unique_ptr<framework::NOP>(
       new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
                          {{"Out", {"y"}}}, {}));
-  net.AddOp(op1);
-  net.InsertOp(0, op1);
+  net.AddOp(*op1);
+  net.InsertOp(0, *op1);
   ASSERT_EQ(2UL, net.ops_.size());
-  net.InsertOp(2, op1);
+  net.InsertOp(2, std::move(op1));
   ASSERT_EQ(3UL, net.ops_.size());
 }
 TEST(NetOp, Clone) {
   NetOp net;
   net.AddOp(
-      std::shared_ptr<framework::NOP>(new framework::NOP{"empty", {}, {}, {}}));
-  net.AddOp(std::shared_ptr<framework::NOP>(
+      std::unique_ptr<framework::NOP>(new framework::NOP{"empty", {}, {}, {}}));
+  net.AddOp(std::unique_ptr<framework::NOP>(
       new framework::NOP{"empty2", {}, {}, {}}));
   net.CompleteAddOp(true);
   auto new_net_op = net.Clone();
......
@@ -42,7 +42,7 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
       rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
                         true /*infer_shape_mode*/);
     }
-    (*stepnet_)->InferShape(*step_scopes[i]);
+    stepnet_->InferShape(*step_scopes[i]);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      true /*infer_shape_mode*/);
@@ -61,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
                         false /*infer_shape_mode*/);
     }
-    (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
+    stepnet_->Run(*step_scopes[step_id], dev_ctx);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      false /*infer_shape_mode*/);
@@ -76,15 +76,15 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
   // Now all variables in scope must be created outside of op.
   PADDLE_ENFORCE_NOT_NULL(stepnet_);
-  PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
-  PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs");
+  PADDLE_ENFORCE(!stepnet_->Outputs().empty(), "stepnet_ op has no outputs");
+  PADDLE_ENFORCE(!stepnet_->Outputs().empty(), "net_op has no outputs");
   if (seq_len_ > step_scopes->size()) {
     for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
       auto& step_scope = scope.NewScope();
       // create step net's temp inputs
-      for (auto& input : (*stepnet_)->Inputs()) {
+      for (auto& input : stepnet_->Inputs()) {
        // the weight are located in parent scope
        for (auto& var_name : input.second) {
          if (!step_scope.FindVar(var_name)) {
@@ -93,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
         }
       }
       // create stepnet's outputs
-      for (const auto& output : (*stepnet_)->Outputs()) {
+      for (const auto& output : stepnet_->Outputs()) {
        for (auto& var_name : output.second) {
          step_scope.NewVar(var_name);
        }
@@ -136,7 +136,7 @@ RecurrentOp::RecurrentOp(const std::string& type,
                          const framework::AttributeMap& attrs)
     : OperatorBase(type, inputs, outputs, attrs) {
   rnn::InitArgument(kArgName, &arg_, *this);
-  alg_.Init(&arg_, &stepnet_);
+  alg_.Init(&arg_, stepnet_.get());
 }
 class RecurrentAlgorithmProtoAndCheckerMaker
@@ -178,7 +178,7 @@ void RecurrentGradientAlgorithm::Run(
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                         false /*infer_shape_mode*/);
     }
-    (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
+    stepnet_->Run(*step_scopes[step_id], dev_ctx);
   }
   LinkBootMemoryGradients(step_scopes[0], false);
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
@@ -215,7 +215,7 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                         true /*infer_shape_mode*/);
     }
-    (*stepnet_)->InferShape(*step_scopes[step_id]);
+    stepnet_->InferShape(*step_scopes[step_id]);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      true /*infer_shape_mode*/);
@@ -228,7 +228,7 @@ RecurrentGradientOp::RecurrentGradientOp(
     const framework::AttributeMap& attrs)
     : OperatorBase(type, inputs, outputs, attrs) {
   rnn::InitArgument(kArgName, &arg_, *this);
-  alg_.Init(&arg_, &stepnet_);
+  alg_.Init(&arg_, stepnet_.get());
 }
 }  // namespace operators
......
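In the recurrent_op changes above, the op now owns its step network through a `std::unique_ptr`, while the algorithm object keeps only a non-owning raw `OperatorBase*` handed over via `stepnet_.get()`, which is why the `(*stepnet_)->` double indirection disappears. A small sketch of that owner/observer split (illustrative `StepNet`/`Algorithm` types, not the Paddle classes; here the step net is created up front so the observer pointer is valid):

#include <memory>
#include <utility>

struct StepNet {
  int Run() const { return 42; }
};

// Non-owning observer: just runs whatever step net it is pointed at.
struct Algorithm {
  void Init(const StepNet* stepnet) { stepnet_ = stepnet; }
  int Run() const { return stepnet_->Run(); }
  const StepNet* stepnet_ = nullptr;
};

// Owner: holds the step net for its whole lifetime.
struct RecurrentLikeOp {
  RecurrentLikeOp() { alg_.Init(stepnet_.get()); }
  void set_stepnet(std::unique_ptr<StepNet> net) { stepnet_ = std::move(net); }
  Algorithm alg_;
  std::unique_ptr<StepNet> stepnet_{new StepNet};
};

int main() {
  RecurrentLikeOp op;
  return op.alg_.Run() == 42 ? 0 : 1;
}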
@@ -34,7 +34,7 @@ class RecurrentAlgorithm {
   void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const;
-  void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
+  void Init(rnn::Argument* arg, framework::OperatorBase* stepnet) {
     PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
     arg_ = arg;
     stepnet_ = stepnet;
@@ -63,7 +63,7 @@ class RecurrentAlgorithm {
   void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
  private:
-  std::shared_ptr<NetOp>* stepnet_;
+  framework::OperatorBase* stepnet_;
   rnn::Argument* arg_;
   mutable size_t seq_len_;
 };
@@ -80,7 +80,7 @@ class RecurrentGradientAlgorithm {
  * operator.
  */
 public:
-  void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
+  void Init(rnn::Argument* arg, framework::OperatorBase* stepnet) {
     PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
     arg_ = std::move(arg);
     stepnet_ = stepnet;
@@ -107,7 +107,7 @@ class RecurrentGradientAlgorithm {
  private:
   rnn::Argument* arg_;
   mutable size_t seq_len_;
-  std::shared_ptr<NetOp>* stepnet_;
+  framework::OperatorBase* stepnet_;
 };
 class RecurrentOp : public framework::OperatorBase {
@@ -133,15 +133,17 @@ class RecurrentOp : public framework::OperatorBase {
     alg_.Run(scope, dev_ctx);
   }
-  void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
-  const NetOp& stepnet() const { return *stepnet_; }
+  void set_stepnet(std::unique_ptr<OperatorBase> net) {
+    stepnet_ = std::move(net);
+  }
+  const OperatorBase& stepnet() const { return *stepnet_; }
   static const rnn::ArgumentName kArgName;
  private:
   RecurrentAlgorithm alg_;
   rnn::Argument arg_;
-  std::shared_ptr<NetOp> stepnet_;
+  std::unique_ptr<OperatorBase> stepnet_;
 };
 class RecurrentGradientOp : public framework::OperatorBase {
@@ -171,12 +173,14 @@ class RecurrentGradientOp : public framework::OperatorBase {
   static const rnn::ArgumentName kArgName;
-  void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; }
-  const NetOp& stepnet() const { return *stepnet_; }
+  void set_stepnet(std::unique_ptr<OperatorBase> net) {
+    stepnet_ = std::move(net);
+  }
+  const OperatorBase& stepnet() const { return *stepnet_; }
  private:
   RecurrentGradientAlgorithm alg_;
-  std::shared_ptr<NetOp> stepnet_;
+  std::unique_ptr<OperatorBase> stepnet_;
   rnn::Argument arg_;
 };
......