提交 8c653ba7 编写于 作者: Y Yu Yang

Complete remove std::shared_ptr

上级 c7f25325
......@@ -15,6 +15,8 @@
#include "paddle/framework/backward.h"
#include <list>
#include <memory>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
......@@ -43,11 +45,11 @@ static bool AllInSet(
return all_in_set;
}
static std::shared_ptr<OperatorBase> NOP() {
auto net_op = std::make_shared<operators::NetOp>();
static std::unique_ptr<OperatorBase> NOP() {
auto net_op = new operators::NetOp();
net_op->SetType("@NOP@");
net_op->CompleteAddOp();
return net_op;
return std::unique_ptr<OperatorBase>(net_op);
}
// Get backward operator from a forward operator, a recursive implementation.
......@@ -62,11 +64,7 @@ static std::shared_ptr<OperatorBase> NOP() {
// operator, in a complex situation, it maybe a NetOp.
//
// See Backward.h for details
static std::shared_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
std::shared_ptr<OperatorBase> BackwardRecursive(
static std::unique_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
// If all input gradients of forwarding operator do not need to calculate,
......@@ -91,7 +89,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
}
// Returned gradient network
auto net = std::make_shared<operators::NetOp>();
auto net = std::unique_ptr<operators::NetOp>();
if (forwardOp.IsNetOp()) {
// Because forwardOp is a net op, it can static_cast.
......@@ -105,14 +103,14 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// reversely travel forwardNet and collect all duplicate outputs.
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it, ++local_op_id) {
auto fwd = *it;
auto& fwd = *it;
auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
net->AddOp(bwd);
ForEachVarName(bwd->Outputs(),
[&dup_output_ops, local_op_id](const std::string& out) {
dup_output_ops[out].emplace_back(local_op_id);
return false;
});
net->AddOp(std::move(bwd));
}
// Get unique ID for this method.
auto uid = uniq_id++;
......@@ -122,7 +120,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// to handle this case. For each duplicate output, rename it to an alias
// (original name with a offset), append an `add` op for its operator,
// and finally sum all the alias variable to the final output variable y.
using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
using Pos = std::pair<size_t, std::unique_ptr<OperatorBase>>;
std::list<Pos> insert_position;
for (auto& dup_output_op : dup_output_ops) {
const std::string& name = dup_output_op.first;
......@@ -150,13 +148,13 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
[](const Pos& l, const Pos& r) { return l.first > r.first; });
for (auto& pos : insert_position) {
net->InsertOp(pos.first + 1, pos.second);
net->InsertOp(pos.first + 1, std::move(pos.second));
}
} else {
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net,
grad_op](const std::string& grad_input) {
ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
const std::string& grad_input) {
if (no_grad_names.count(grad_input)) {
// +1 for \0
std::string prefix = grad_input.substr(
......@@ -190,20 +188,20 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
const auto& stepnet_op =
*static_cast<const OperatorBase*>(&rnnop.stepnet());
// create stepnet's gradient op
auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id);
rnn_grad_op->set_stepnet(
std::static_pointer_cast<operators::NetOp>(grad_stepnet));
BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
}
if (net->ops_.empty()) { // Current no aux op is added to network
return grad_op;
}
net->AddOp(grad_op);
net->AddOp(std::move(grad_op));
}
net->SetType("@GENERATED_BACKWARD@");
net->CompleteAddOp();
return net;
} // namespace framework
return std::unique_ptr<OperatorBase>(
static_cast<OperatorBase*>(net.release()));
}
// See header for comments
std::shared_ptr<OperatorBase> Backward(
......
......@@ -174,7 +174,7 @@ class OpRegistry {
}
}
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameMap& inputs,
const VarNameMap& outputs,
AttributeMap attrs) {
......@@ -183,7 +183,7 @@ class OpRegistry {
"Operator '%s' has not been registered.", type);
it->second.checker_->Check(attrs);
auto op = it->second.creator_(type, inputs, outputs, attrs);
return std::shared_ptr<OperatorBase>(op);
return std::unique_ptr<OperatorBase>(op);
}
static VarNameMap ConvertOpDescVarsToVarNameMap(
......@@ -199,7 +199,7 @@ class OpRegistry {
return ret_val;
}
static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs;
......@@ -210,11 +210,10 @@ class OpRegistry {
return CreateOp(op_desc.type(), inputs, outputs, attrs);
}
static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
PADDLE_ENFORCE(!op.IsNetOp(),
"Use framework::Backward to get backward ops");
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
return grad_op;
return std::unique_ptr<OperatorBase>(BuildGradOp(&op));
}
static std::unordered_map<std::string, const OpInfo>& op_info_map() {
......
......@@ -76,8 +76,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale);
std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
......@@ -118,8 +117,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE(op_desc.IsInitialized());
std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
......
......@@ -207,8 +207,7 @@ All parameter, weight, gradient are variables in Paddle.
.def(py::init<>())
.def("__str__", string::to_string<const platform::CPUPlace &>);
py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
m, "Operator");
py::class_<OperatorBase> operator_base(m, "Operator");
operator_base.def_static("create", [](py::bytes protobin) {
OpDesc desc;
......@@ -228,25 +227,23 @@ All parameter, weight, gradient are variables in Paddle.
ExposeOperator(operator_base);
py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net(m, "Net");
py::class_<operators::NetOp> net(m, "Net");
net.def_static("create",
[]() -> std::shared_ptr<operators::NetOp> {
auto retv = std::make_shared<operators::NetOp>();
[]() -> operators::NetOp * {
auto *retv = new operators::NetOp;
retv->SetType("plain_net");
return retv;
})
.def("add_op", &operators::NetOp::AddOp)
.def("add_op", [](operators::NetOp &self,
const OperatorBase &op) { self.AddOp(op); })
.def("add_op",
[](operators::NetOp &self,
const std::shared_ptr<operators::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
[](operators::NetOp &self, const operators::NetOp &net) -> void {
self.AddOp(net);
})
.def("add_op",
[](operators::NetOp &self,
const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(rnn));
})
const operators::RecurrentOp &rnn) -> void { self.AddOp(rnn); })
.def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp();
......@@ -255,12 +252,11 @@ All parameter, weight, gradient are variables in Paddle.
ExposeOperator(net);
// recurrent_op
py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
rnn(m, "RecurrentOp");
py::class_<operators::RecurrentOp> rnn(m, "RecurrentOp");
rnn.def_static(
"create",
[](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
[](py::bytes protobin) -> operators::RecurrentOp * {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
......@@ -268,13 +264,12 @@ All parameter, weight, gradient are variables in Paddle.
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
return static_cast<operators::RecurrentOp *>(rnn_op.release());
})
.def("set_stepnet",
[](operators::RecurrentOp &self,
const std::shared_ptr<operators::NetOp> &net) -> void {
self.set_stepnet(net);
});
.def("set_stepnet", [](operators::RecurrentOp &self,
const operators::NetOp &net) -> void {
self.set_stepnet(net.Clone());
});
ExposeOperator(rnn);
m.def("unique_integer", UniqueIntegerGenerator);
......
......@@ -45,11 +45,11 @@ class NetOp : public framework::OperatorBase {
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
this->ops_.reserve(o.ops_.size());
std::transform(o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_),
[](const std::shared_ptr<OperatorBase>& op)
-> std::shared_ptr<OperatorBase> {
return std::shared_ptr<OperatorBase>(op->Clone());
});
std::transform(
o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_),
[](const std::unique_ptr<framework::OperatorBase>& op) {
return std::unique_ptr<framework::OperatorBase>(op->Clone());
});
this->CompleteAddOp();
}
......@@ -86,21 +86,42 @@ class NetOp : public framework::OperatorBase {
return true;
}
void AddOp(const framework::OperatorBase& op) { AddOp(op.Clone()); }
/**
* @brief Add an operator by ptr
*/
void AddOp(const std::shared_ptr<OperatorBase>& op) {
void AddOp(framework::OperatorBase* op, bool own) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op");
ops_.push_back(op);
if (!own) {
op = op->Clone().release();
}
ops_.emplace_back(op);
}
void InsertOp(size_t pos, const std::shared_ptr<OperatorBase>& op) {
void AddOp(std::unique_ptr<framework::OperatorBase>&& op) {
AddOp(op.release(), true);
}
void InsertOp(size_t pos, framework::OperatorBase* op, bool own) {
PADDLE_ENFORCE(!add_op_done_,
"Cannot InsertOp when this network is sealed");
PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op");
PADDLE_ENFORCE_LE(pos, ops_.size(), "Out of range");
ops_.insert(ops_.begin() + pos, op);
if (!own) {
op = op->Clone().release();
}
ops_.insert(ops_.begin() + pos,
std::unique_ptr<framework::OperatorBase>(op));
}
void InsertOp(size_t pos, std::unique_ptr<framework::OperatorBase>&& op) {
InsertOp(pos, op.release(), true);
}
void InsertOp(size_t pos, const framework::OperatorBase& op) {
InsertOp(pos, op.Clone());
}
void CompleteAddOp(bool calculate = true);
......@@ -112,7 +133,7 @@ class NetOp : public framework::OperatorBase {
std::unique_ptr<framework::OperatorBase> Clone() const override;
std::vector<std::shared_ptr<OperatorBase>> ops_;
std::vector<std::unique_ptr<framework::OperatorBase>> ops_;
private:
bool add_op_done_{false};
......
......@@ -38,15 +38,12 @@ TEST(OpKernel, all) {
auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr);
auto op1 = std::shared_ptr<TestOp>(
net->AddOp(std::unique_ptr<TestOp>(
new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"Out", {"y"}}}, {}));
net->AddOp(op1);
auto op2 = std::shared_ptr<TestOp>(
{{"Out", {"y"}}}, {})));
net->AddOp(std::unique_ptr<TestOp>(
new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}},
{{"Out", {"z"}}}, {}));
net->AddOp(op2);
{{"Out", {"z"}}}, {})));
net->CompleteAddOp();
AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"},
......@@ -61,21 +58,21 @@ TEST(OpKernel, all) {
TEST(NetOp, insert_op) {
NetOp net;
auto op1 = std::shared_ptr<framework::NOP>(
auto op1 = std::unique_ptr<framework::NOP>(
new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"Out", {"y"}}}, {}));
net.AddOp(op1);
net.InsertOp(0, op1);
net.AddOp(*op1);
net.InsertOp(0, *op1);
ASSERT_EQ(2UL, net.ops_.size());
net.InsertOp(2, op1);
net.InsertOp(2, std::move(op1));
ASSERT_EQ(3UL, net.ops_.size());
}
TEST(NetOp, Clone) {
NetOp net;
net.AddOp(
std::shared_ptr<framework::NOP>(new framework::NOP{"empty", {}, {}, {}}));
net.AddOp(std::shared_ptr<framework::NOP>(
std::unique_ptr<framework::NOP>(new framework::NOP{"empty", {}, {}, {}}));
net.AddOp(std::unique_ptr<framework::NOP>(
new framework::NOP{"empty2", {}, {}, {}}));
net.CompleteAddOp(true);
auto new_net_op = net.Clone();
......
......@@ -42,7 +42,7 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
true /*infer_shape_mode*/);
}
(*stepnet_)->InferShape(*step_scopes[i]);
stepnet_->InferShape(*step_scopes[i]);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
true /*infer_shape_mode*/);
......@@ -61,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
false /*infer_shape_mode*/);
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
stepnet_->Run(*step_scopes[step_id], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
false /*infer_shape_mode*/);
......@@ -76,15 +76,15 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
// Now all variables in scope must be created outside of op.
PADDLE_ENFORCE_NOT_NULL(stepnet_);
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs");
PADDLE_ENFORCE(!stepnet_->Outputs().empty(), "stepnet_ op has no outputs");
PADDLE_ENFORCE(!stepnet_->Outputs().empty(), "net_op has no outputs");
if (seq_len_ > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
auto& step_scope = scope.NewScope();
// create step net's temp inputs
for (auto& input : (*stepnet_)->Inputs()) {
for (auto& input : stepnet_->Inputs()) {
// the weight are located in parent scope
for (auto& var_name : input.second) {
if (!step_scope.FindVar(var_name)) {
......@@ -93,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
}
}
// create stepnet's outputs
for (const auto& output : (*stepnet_)->Outputs()) {
for (const auto& output : stepnet_->Outputs()) {
for (auto& var_name : output.second) {
step_scope.NewVar(var_name);
}
......@@ -136,7 +136,7 @@ RecurrentOp::RecurrentOp(const std::string& type,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
rnn::InitArgument(kArgName, &arg_, *this);
alg_.Init(&arg_, &stepnet_);
alg_.Init(&arg_, stepnet_.get());
}
class RecurrentAlgorithmProtoAndCheckerMaker
......@@ -178,7 +178,7 @@ void RecurrentGradientAlgorithm::Run(
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
false /*infer_shape_mode*/);
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
stepnet_->Run(*step_scopes[step_id], dev_ctx);
}
LinkBootMemoryGradients(step_scopes[0], false);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
......@@ -215,7 +215,7 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
true /*infer_shape_mode*/);
}
(*stepnet_)->InferShape(*step_scopes[step_id]);
stepnet_->InferShape(*step_scopes[step_id]);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
true /*infer_shape_mode*/);
......@@ -228,7 +228,7 @@ RecurrentGradientOp::RecurrentGradientOp(
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
rnn::InitArgument(kArgName, &arg_, *this);
alg_.Init(&arg_, &stepnet_);
alg_.Init(&arg_, stepnet_.get());
}
} // namespace operators
......
......@@ -34,7 +34,7 @@ class RecurrentAlgorithm {
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;
void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
void Init(rnn::Argument* arg, framework::OperatorBase* stepnet) {
PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
arg_ = arg;
stepnet_ = stepnet;
......@@ -63,7 +63,7 @@ class RecurrentAlgorithm {
void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
private:
std::shared_ptr<NetOp>* stepnet_;
framework::OperatorBase* stepnet_;
rnn::Argument* arg_;
mutable size_t seq_len_;
};
......@@ -80,7 +80,7 @@ class RecurrentGradientAlgorithm {
* operator.
*/
public:
void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
void Init(rnn::Argument* arg, framework::OperatorBase* stepnet) {
PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
arg_ = std::move(arg);
stepnet_ = stepnet;
......@@ -107,7 +107,7 @@ class RecurrentGradientAlgorithm {
private:
rnn::Argument* arg_;
mutable size_t seq_len_;
std::shared_ptr<NetOp>* stepnet_;
framework::OperatorBase* stepnet_;
};
class RecurrentOp : public framework::OperatorBase {
......@@ -133,15 +133,17 @@ class RecurrentOp : public framework::OperatorBase {
alg_.Run(scope, dev_ctx);
}
void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
const NetOp& stepnet() const { return *stepnet_; }
void set_stepnet(std::unique_ptr<OperatorBase> net) {
stepnet_ = std::move(net);
}
const OperatorBase& stepnet() const { return *stepnet_; }
static const rnn::ArgumentName kArgName;
private:
RecurrentAlgorithm alg_;
rnn::Argument arg_;
std::shared_ptr<NetOp> stepnet_;
std::unique_ptr<OperatorBase> stepnet_;
};
class RecurrentGradientOp : public framework::OperatorBase {
......@@ -171,12 +173,14 @@ class RecurrentGradientOp : public framework::OperatorBase {
static const rnn::ArgumentName kArgName;
void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; }
const NetOp& stepnet() const { return *stepnet_; }
void set_stepnet(std::unique_ptr<OperatorBase> net) {
stepnet_ = std::move(net);
}
const OperatorBase& stepnet() const { return *stepnet_; }
private:
RecurrentGradientAlgorithm alg_;
std::shared_ptr<NetOp> stepnet_;
std::unique_ptr<OperatorBase> stepnet_;
rnn::Argument arg_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册