提交 f15e0830 编写于 作者: Y Yu Yang

Remove std::shared_ptr in Python & C++

* Also simplify pybind implementation by using OperatorBase as holder
  type.
上级 7f5338a7
......@@ -89,7 +89,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
}
// Returned gradient network
auto net = std::unique_ptr<operators::NetOp>();
auto net = std::unique_ptr<operators::NetOp>(new operators::NetOp());
if (forwardOp.IsNetOp()) {
// Because forwardOp is a net op, it can static_cast.
......@@ -204,7 +204,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
}
// See header for comments
std::shared_ptr<OperatorBase> Backward(
std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_names;
......
......@@ -20,7 +20,7 @@ namespace framework {
// Create the backward operator from a forward operator.
// TODO(yuyang18): Add more API reference comment.
extern std::shared_ptr<OperatorBase> Backward(
extern std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework
......
......@@ -180,8 +180,7 @@ TEST(Backward, simple_op_not_need_grad) {
auto no_input_gop = f::Backward(*fwd, {"x", "b"});
ASSERT_NE(no_input_gop, nullptr);
ASSERT_TRUE(no_input_gop->IsNetOp());
ASSERT_EQ(0UL,
std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
ASSERT_EQ(0UL, static_cast<ops::NetOp *>(no_input_gop.get())->ops_.size());
}
TEST(Backward, net_fc_backward_normal) {
......
......@@ -48,29 +48,6 @@ namespace framework {
using Tensor = framework::Tensor;
template <typename ClassType>
void ExposeOperator(ClassType &m) {
m.def("infer_shape", &ClassType::type::InferShape)
.def("run", &ClassType::type::Run)
.def("type",
[](const typename ClassType::type &op) -> std::string {
return op.Type();
})
.def("outputs",
[](const typename ClassType::type &op)
-> std::map<std::string, std::vector<std::string>> {
return op.Outputs();
})
.def("inputs",
[](const typename ClassType::type &op) { return op.Inputs(); })
.def("__str__", &ClassType::type::DebugString)
.def("no_intermediate_outputs",
[](const typename ClassType::type &op) {
return op.OutputVars(false);
})
.def("support_gpu", &ClassType::type::SupportGPU);
}
static size_t UniqueIntegerGenerator() {
static std::atomic<size_t> generator;
return generator.fetch_add(1);
......@@ -207,70 +184,69 @@ All parameter, weight, gradient are variables in Paddle.
.def(py::init<>())
.def("__str__", string::to_string<const platform::CPUPlace &>);
py::class_<OperatorBase> operator_base(m, "Operator");
operator_base.def_static("create", [](py::bytes protobin) {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
});
operator_base.def("backward",
[](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
return Backward(forwardOp, no_grad_vars);
});
ExposeOperator(operator_base);
py::class_<operators::NetOp> net(m, "Net");
py::class_<OperatorBase>(m, "Operator")
.def_static("create",
[](py::bytes protobin) {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
})
.def("backward",
[](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
return Backward(forwardOp, no_grad_vars).release();
})
.def("infer_shape", &OperatorBase::InferShape)
.def("run", &OperatorBase::Run)
.def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs",
[](const OperatorBase &op)
-> std::map<std::string, std::vector<std::string>> {
return op.Outputs();
})
.def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
.def("__str__", &OperatorBase::DebugString)
.def("no_intermediate_outputs",
[](const OperatorBase &op) { return op.OutputVars(false); })
.def("support_gpu", &OperatorBase::SupportGPU);
net.def_static("create",
[]() -> operators::NetOp * {
auto *retv = new operators::NetOp;
retv->SetType("plain_net");
return retv;
})
py::class_<operators::NetOp, OperatorBase>(m, "Net")
.def_static("create",
[]() -> operators::NetOp * {
auto *retv = new operators::NetOp;
retv->SetType("plain_net");
return retv;
})
.def("add_op", [](operators::NetOp &self,
const OperatorBase &op) { self.AddOp(op); })
.def("add_op",
[](operators::NetOp &self, const operators::NetOp &net) -> void {
self.AddOp(net);
})
.def("add_op",
[](operators::NetOp &self,
const operators::RecurrentOp &rnn) -> void { self.AddOp(rnn); })
.def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp();
});
ExposeOperator(net);
// recurrent_op
py::class_<operators::RecurrentOp> rnn(m, "RecurrentOp");
rnn.def_static(
"create",
[](py::bytes protobin) -> operators::RecurrentOp * {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return static_cast<operators::RecurrentOp *>(rnn_op.release());
})
py::class_<operators::RecurrentOp, OperatorBase>(m, "RecurrentOp")
.def_static(
"create",
[](py::bytes protobin) -> operators::RecurrentOp * {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return static_cast<operators::RecurrentOp *>(rnn_op.release());
})
.def("set_stepnet", [](operators::RecurrentOp &self,
const operators::NetOp &net) -> void {
self.set_stepnet(net.Clone());
});
ExposeOperator(rnn);
m.def("unique_integer", UniqueIntegerGenerator);
......
......@@ -41,9 +41,7 @@ class NetOp : public framework::OperatorBase {
NetOp(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const framework::AttributeMap& attrs);
NetOp(const NetOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
NetOp(const NetOp& o) : framework::OperatorBase(o.type_, {}, {}, o.attrs_) {
this->ops_.reserve(o.ops_.size());
std::transform(
o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_),
......
......@@ -42,7 +42,7 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
true /*infer_shape_mode*/);
}
stepnet_->InferShape(*step_scopes[i]);
(*stepnet_)->InferShape(*step_scopes[i]);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
true /*infer_shape_mode*/);
......@@ -61,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
false /*infer_shape_mode*/);
}
stepnet_->Run(*step_scopes[step_id], dev_ctx);
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
false /*infer_shape_mode*/);
......@@ -76,15 +76,15 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
// Now all variables in scope must be created outside of op.
PADDLE_ENFORCE_NOT_NULL(stepnet_);
PADDLE_ENFORCE(!stepnet_->Outputs().empty(), "stepnet_ op has no outputs");
PADDLE_ENFORCE(!stepnet_->Outputs().empty(), "net_op has no outputs");
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs");
if (seq_len_ > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
auto& step_scope = scope.NewScope();
// create step net's temp inputs
for (auto& input : stepnet_->Inputs()) {
for (auto& input : (*stepnet_)->Inputs()) {
// the weight are located in parent scope
for (auto& var_name : input.second) {
if (!step_scope.FindVar(var_name)) {
......@@ -93,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
}
}
// create stepnet's outputs
for (const auto& output : stepnet_->Outputs()) {
for (const auto& output : (*stepnet_)->Outputs()) {
for (auto& var_name : output.second) {
step_scope.NewVar(var_name);
}
......@@ -136,7 +136,7 @@ RecurrentOp::RecurrentOp(const std::string& type,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
rnn::InitArgument(kArgName, &arg_, *this);
alg_.Init(&arg_, stepnet_.get());
alg_.Init(&arg_, &stepnet_);
}
class RecurrentAlgorithmProtoAndCheckerMaker
......@@ -178,7 +178,7 @@ void RecurrentGradientAlgorithm::Run(
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
false /*infer_shape_mode*/);
}
stepnet_->Run(*step_scopes[step_id], dev_ctx);
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
LinkBootMemoryGradients(step_scopes[0], false);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
......@@ -215,7 +215,7 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
true /*infer_shape_mode*/);
}
stepnet_->InferShape(*step_scopes[step_id]);
(*stepnet_)->InferShape(*step_scopes[step_id]);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
true /*infer_shape_mode*/);
......@@ -228,7 +228,7 @@ RecurrentGradientOp::RecurrentGradientOp(
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
rnn::InitArgument(kArgName, &arg_, *this);
alg_.Init(&arg_, stepnet_.get());
alg_.Init(&arg_, &stepnet_);
}
} // namespace operators
......
......@@ -34,7 +34,8 @@ class RecurrentAlgorithm {
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;
void Init(rnn::Argument* arg, framework::OperatorBase* stepnet) {
void Init(rnn::Argument* arg,
std::unique_ptr<framework::OperatorBase>* stepnet) {
PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
arg_ = arg;
stepnet_ = stepnet;
......@@ -63,7 +64,7 @@ class RecurrentAlgorithm {
void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
private:
framework::OperatorBase* stepnet_;
std::unique_ptr<framework::OperatorBase>* stepnet_;
rnn::Argument* arg_;
mutable size_t seq_len_;
};
......@@ -80,7 +81,8 @@ class RecurrentGradientAlgorithm {
* operator.
*/
public:
void Init(rnn::Argument* arg, framework::OperatorBase* stepnet) {
void Init(rnn::Argument* arg,
std::unique_ptr<framework::OperatorBase>* stepnet) {
PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
arg_ = std::move(arg);
stepnet_ = stepnet;
......@@ -107,7 +109,7 @@ class RecurrentGradientAlgorithm {
private:
rnn::Argument* arg_;
mutable size_t seq_len_;
framework::OperatorBase* stepnet_;
std::unique_ptr<framework::OperatorBase>* stepnet_;
};
class RecurrentOp : public framework::OperatorBase {
......
......@@ -165,7 +165,6 @@ class GradientChecker(unittest.TestCase):
for no_grad in no_grad_set:
if no_grad not in in_names:
raise ValueError("no_grad should be in in_names")
backward_op = core.Operator.backward(forward_op, no_grad_set)
bwd_outputs = backward_op.outputs()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册