提交 f1a2d204 编写于 作者: M minqiyang 提交者: ceci3

invoke backward_hooks after reduce op's depcounts map

test=develop
上级 e0a2b472
...@@ -155,6 +155,14 @@ void BlockDesc::RemoveOp(size_t s, size_t e) { ...@@ -155,6 +155,14 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
ops_.erase(ops_.begin() + s, ops_.begin() + e); ops_.erase(ops_.begin() + s, ops_.begin() + e);
} }
void BlockDesc::RemoveOpInternal(const OpDesc *op_desc) {
for (auto it = ops_.begin(); it != ops_.end(); ++it) {
if (it->get() == op_desc) {
ops_.erase(it);
}
}
}
std::vector<OpDesc *> BlockDesc::AllOps() const { std::vector<OpDesc *> BlockDesc::AllOps() const {
std::vector<OpDesc *> res; std::vector<OpDesc *> res;
for (const auto &op : ops_) { for (const auto &op : ops_) {
......
...@@ -93,6 +93,8 @@ class BlockDesc { ...@@ -93,6 +93,8 @@ class BlockDesc {
*/ */
void RemoveOp(size_t s, size_t e); void RemoveOp(size_t s, size_t e);
void RemoveOpInternal(const OpDesc *op_desc);
void RemoveVar(const std::string &name) { vars_.erase(name); } void RemoveVar(const std::string &name) { vars_.erase(name); }
std::vector<OpDesc *> AllOps() const; std::vector<OpDesc *> AllOps() const;
......
...@@ -24,3 +24,11 @@ limitations under the License. */ ...@@ -24,3 +24,11 @@ limitations under the License. */
#pragma pop_macro("_XOPEN_SOURCE") #pragma pop_macro("_XOPEN_SOURCE")
#pragma pop_macro("_POSIX_C_SOURCE") #pragma pop_macro("_POSIX_C_SOURCE")
#if !defined(PYBIND11_HIDDEN)
#ifdef _WIN32
#define PYBIND11_HIDDEN __declspec(dllexport)
#else
#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
#endif
#endif
...@@ -118,16 +118,19 @@ class Autograd { ...@@ -118,16 +118,19 @@ class Autograd {
while (!ready.empty()) { while (!ready.empty()) {
OpBase* ready_op = ready.front(); OpBase* ready_op = ready.front();
ready.pop_front(); ready.pop_front();
LOG(ERROR) << "ApplyGrad Start";
std::map<std::string, std::vector<VarBase*>> input_grads = std::map<std::string, std::vector<VarBase*>> input_grads =
ready_op->ApplyGrad(); ready_op->ApplyGrad();
for (auto it : input_grads) { for (auto it : input_grads) {
const std::vector<VarBase*>& ingrads = it.second; const std::vector<VarBase*>& ingrads = it.second;
LOG(ERROR) << "XX";
for (size_t i = 0; i < ingrads.size(); ++i) { for (size_t i = 0; i < ingrads.size(); ++i) {
if (!ingrads[i]) continue; if (!ingrads[i]) continue;
if (ready_op->input_vars_[it.first][i]->IsStopGradient()) { if (ready_op->input_vars_[it.first][i]->IsStopGradient()) {
continue; continue;
} }
LOG(ERROR) << "XX";
OpBase* pre_op = ready_op->pre_ops_[it.first][i]; OpBase* pre_op = ready_op->pre_ops_[it.first][i];
if (!pre_op) continue; if (!pre_op) continue;
...@@ -137,8 +140,13 @@ class Autograd { ...@@ -137,8 +140,13 @@ class Autograd {
if (pre_op_ready) { if (pre_op_ready) {
ready.push_back(pre_op); ready.push_back(pre_op);
} }
LOG(ERROR) << "XX";
} }
} }
ready_op->InvokeBackwardHooks();
LOG(ERROR) << "ApplyGrad End";
} }
} }
...@@ -221,8 +229,10 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { ...@@ -221,8 +229,10 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]); grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
} else { } else {
grad_outputs.resize(grad_op_descs_.size()); grad_outputs.resize(grad_op_descs_.size());
LOG(ERROR) << "ApplyGrad " << grad_op_descs_.size();
for (size_t k = 0; k < grad_op_descs_.size(); ++k) { for (size_t k = 0; k < grad_op_descs_.size(); ++k) {
framework::OpDesc* grad_op_desc = grad_op_descs_[k]; framework::OpDesc* grad_op_desc = grad_op_descs_[k];
LOG(ERROR) << "op grad " << grad_op_desc->Type();
VLOG(3) << "op grad " << grad_op_desc->Type(); VLOG(3) << "op grad " << grad_op_desc->Type();
for (auto it : grad_output_vars_[k]) { for (auto it : grad_output_vars_[k]) {
auto& outputs = grad_outputs[k][it.first]; auto& outputs = grad_outputs[k][it.first];
...@@ -234,12 +244,16 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { ...@@ -234,12 +244,16 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
} }
} }
LOG(ERROR) << "op grad " << grad_op_desc->Type();
framework::RuntimeContext ctx(grad_input_vars_[k], grad_outputs[k]); framework::RuntimeContext ctx(grad_input_vars_[k], grad_outputs[k]);
// No need to do compile time infer shape here. // No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_); // grad_op_desc_->InferShape(*block_);
grad_op_desc->InferVarType(block_); grad_op_desc->InferVarType(block_);
LOG(ERROR) << "op grad " << grad_op_desc->Type();
std::unique_ptr<framework::OperatorBase> opbase = std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc); framework::OpRegistry::CreateOp(*grad_op_desc);
framework::OperatorWithKernel* op_kernel = framework::OperatorWithKernel* op_kernel =
...@@ -254,6 +268,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { ...@@ -254,6 +268,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
} }
} }
LOG(ERROR) << "delete grad start ";
for (size_t k = 0; k < grad_output_vars_.size(); ++k) { for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
for (auto it : grad_output_vars_[k]) { for (auto it : grad_output_vars_[k]) {
auto& outputs = grad_outputs[k][it.first]; auto& outputs = grad_outputs[k][it.first];
...@@ -272,6 +288,24 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { ...@@ -272,6 +288,24 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
return input_vars_; return input_vars_;
} }
void OpBase::InvokeBackwardHooks() {
LOG(ERROR) << "call backward start ";
// call backward hooks
for (py::object& callable : backward_hooks_) {
callable(this);
}
LOG(ERROR) << "call backward end ";
}
void OpBase::RegisterBackwardHooks(const py::object& callable) {
LOG(ERROR) << "Register backward hooks " << trace_id_;
// TODO(minqiyang): check the callable format
backward_hooks_.push_back(callable);
}
void VarBase::RunBackward() { void VarBase::RunBackward() {
if (!pre_op_) return; if (!pre_op_) return;
......
...@@ -123,7 +123,8 @@ class VarBase { ...@@ -123,7 +123,8 @@ class VarBase {
private: private:
VarBase(framework::Variable* var, VarBase* grad, bool stop_gradient) VarBase(framework::Variable* var, VarBase* grad, bool stop_gradient)
: var_desc_(nullptr), : name_(),
var_desc_(nullptr),
var_(var), var_(var),
grads_(grad), grads_(grad),
block_(nullptr), block_(nullptr),
...@@ -133,7 +134,7 @@ class VarBase { ...@@ -133,7 +134,7 @@ class VarBase {
public: public:
virtual ~VarBase() { virtual ~VarBase() {
LOG(ERROR) << "remove var " << name_; LOG(ERROR) << "remove var " << name_.c_str();
if (block_) { if (block_) {
block_->RemoveVar(name_); block_->RemoveVar(name_);
...@@ -191,6 +192,7 @@ class VarBase { ...@@ -191,6 +192,7 @@ class VarBase {
return string::Sprintf("%s@IGrad", var_desc_->Name()); return string::Sprintf("%s@IGrad", var_desc_->Name());
} }
std::string name_;
framework::VarDesc* var_desc_; framework::VarDesc* var_desc_;
framework::Variable* var_; framework::Variable* var_;
...@@ -203,20 +205,20 @@ class VarBase { ...@@ -203,20 +205,20 @@ class VarBase {
OpBase* pre_op_; OpBase* pre_op_;
std::string pre_op_out_name_; std::string pre_op_out_name_;
int pre_op_out_idx_; int pre_op_out_idx_;
std::string name_;
}; };
/* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its /* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its
* gradient. This object should be managed totally by Python intepreter. * gradient. This object should be managed totally by Python intepreter.
*/ */
class OpBase { class PYBIND11_HIDDEN OpBase {
public: public:
OpBase() OpBase()
: op_desc_(nullptr), : op_desc_(nullptr),
forward_id_(-1), forward_id_(-1),
backward_id_(-1), backward_id_(-1),
trace_id_(-1), trace_id_(-1),
place_(platform::CPUPlace()) {} place_(platform::CPUPlace()),
backward_hooks_() {}
virtual ~OpBase() { virtual ~OpBase() {
for (framework::OpDesc* desc : grad_op_descs_) { for (framework::OpDesc* desc : grad_op_descs_) {
...@@ -226,12 +228,18 @@ class OpBase { ...@@ -226,12 +228,18 @@ class OpBase {
LOG(ERROR) << "remove op " << op_desc_->Type() << " id " << trace_id_; LOG(ERROR) << "remove op " << op_desc_->Type() << " id " << trace_id_;
if (block_) { if (block_) {
block_->RemoveOp(trace_id_, trace_id_ + 1); block_->RemoveOpInternal(op_desc_);
} }
LOG(ERROR) << "remove op end " << trace_id_;
} }
std::map<std::string, std::vector<VarBase*>> ApplyGrad(); std::map<std::string, std::vector<VarBase*>> ApplyGrad();
void RegisterBackwardHooks(const py::object& callable);
void InvokeBackwardHooks();
// One of `op_desc_` or `forward_id_` is set, not both. // One of `op_desc_` or `forward_id_` is set, not both.
// For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_. // For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_.
framework::OpDesc* op_desc_; framework::OpDesc* op_desc_;
...@@ -257,6 +265,8 @@ class OpBase { ...@@ -257,6 +265,8 @@ class OpBase {
std::vector<framework::VariableValueMap> grad_output_vars_; std::vector<framework::VariableValueMap> grad_output_vars_;
framework::BlockDesc* block_; framework::BlockDesc* block_;
std::vector<py::object> backward_hooks_;
}; };
class Layer { class Layer {
......
...@@ -33,7 +33,7 @@ class Layer : public imperative::Layer { ...@@ -33,7 +33,7 @@ class Layer : public imperative::Layer {
} }
}; };
class PyOpBase : public imperative::OpBase { class PYBIND11_HIDDEN PyOpBase : public imperative::OpBase {
public: public:
using imperative::OpBase::OpBase; // Inherit constructors using imperative::OpBase::OpBase; // Inherit constructors
}; };
......
...@@ -169,6 +169,18 @@ PYBIND11_MODULE(core, m) { ...@@ -169,6 +169,18 @@ PYBIND11_MODULE(core, m) {
py::return_value_policy::take_ownership) py::return_value_policy::take_ownership)
.def("value", [](const imperative::VarBase &self) { return self.var_; }, .def("value", [](const imperative::VarBase &self) { return self.var_; },
py::return_value_policy::reference) py::return_value_policy::reference)
.def_property("name",
[](const imperative::VarBase &self) { return self.name_; },
[](imperative::VarBase &self, const std::string &name) {
self.name_ = name;
LOG(ERROR) << "create ivar name " << self.name_;
})
.def_property("block",
[](const imperative::VarBase &self) { return self.block_; },
[](imperative::VarBase &self, framework::BlockDesc *block) {
self.block_ = block;
},
py::return_value_policy::reference)
.def_property( .def_property(
"desc", "desc",
[](const imperative::VarBase &self) { return self.var_desc_; }, [](const imperative::VarBase &self) { return self.var_desc_; },
...@@ -185,6 +197,10 @@ PYBIND11_MODULE(core, m) { ...@@ -185,6 +197,10 @@ PYBIND11_MODULE(core, m) {
py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC") py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
.def(py::init<>()) .def(py::init<>())
.def("register_backward_hooks",
[](imperative::OpBase &self, const py::object &callable) {
self.RegisterBackwardHooks(callable);
})
.def_property( .def_property(
"desc", [](const imperative::OpBase &self) { return self.op_desc_; }, "desc", [](const imperative::OpBase &self) { return self.op_desc_; },
[](imperative::OpBase &self, framework::OpDesc *op_desc) { [](imperative::OpBase &self, framework::OpDesc *op_desc) {
......
...@@ -390,11 +390,11 @@ class Variable(object): ...@@ -390,11 +390,11 @@ class Variable(object):
if _in_imperative_mode(): if _in_imperative_mode():
# record vars in tracer rather than blocks # record vars in tracer rather than blocks
self._ivar = kwargs.get("ivar", None) self._ivar = kwargs.get("ivar", None)
self._ivar.block = block.desc
self._ivar.name = name
if not self._ivar: if not self._ivar:
self._ivar = core.VarBase(stop_gradient) self._ivar = core.VarBase(stop_gradient)
self._ivar.desc = self.desc self._ivar.desc = self.desc
self._ivar.block = block.desc
self._ivar.name = name
if persistable: if persistable:
self.block.vars[name] = self self.block.vars[name] = self
else: else:
......
...@@ -146,69 +146,69 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -146,69 +146,69 @@ class TestImperativeMnist(unittest.TestCase):
for param in mnist.parameters(): for param in mnist.parameters():
dy_param_value[param.name] = param._numpy() dy_param_value[param.name] = param._numpy()
with new_program_scope(): # with new_program_scope():
fluid.default_startup_program().random_seed = seed # fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed # fluid.default_main_program().random_seed = seed
exe = fluid.Executor(fluid.CPUPlace( # exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) # ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
mnist = MNIST("mnist") # mnist = MNIST("mnist")
sgd = SGDOptimizer(learning_rate=1e-3) # sgd = SGDOptimizer(learning_rate=1e-3)
train_reader = paddle.batch( # train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True) # paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
img = fluid.layers.data( # img = fluid.layers.data(
name='pixel', shape=[1, 28, 28], dtype='float32') # name='pixel', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') # label = fluid.layers.data(name='label', shape=[1], dtype='int64')
cost = mnist(img) # cost = mnist(img)
loss = fluid.layers.cross_entropy(cost, label) # loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss) # avg_loss = fluid.layers.mean(loss)
sgd.minimize(avg_loss) # sgd.minimize(avg_loss)
# initialize params and fetch them # # initialize params and fetch them
static_param_init_value = {} # static_param_init_value = {}
static_param_name_list = [] # static_param_name_list = []
for param in mnist.parameters(): # for param in mnist.parameters():
static_param_name_list.append(param.name) # static_param_name_list.append(param.name)
out = exe.run(fluid.default_startup_program(), # out = exe.run(fluid.default_startup_program(),
fetch_list=static_param_name_list) # fetch_list=static_param_name_list)
for i in range(len(static_param_name_list)): # for i in range(len(static_param_name_list)):
static_param_init_value[static_param_name_list[i]] = out[i] # static_param_init_value[static_param_name_list[i]] = out[i]
for epoch in range(epoch_num): # for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()): # for batch_id, data in enumerate(train_reader()):
static_x_data = np.array( # static_x_data = np.array(
[x[0].reshape(1, 28, 28) # [x[0].reshape(1, 28, 28)
for x in data]).astype('float32') # for x in data]).astype('float32')
y_data = np.array( # y_data = np.array(
[x[1] for x in data]).astype('int64').reshape([128, 1]) # [x[1] for x in data]).astype('int64').reshape([128, 1])
fetch_list = [avg_loss.name] # fetch_list = [avg_loss.name]
fetch_list.extend(static_param_name_list) # fetch_list.extend(static_param_name_list)
out = exe.run( # out = exe.run(
fluid.default_main_program(), # fluid.default_main_program(),
feed={"pixel": static_x_data, # feed={"pixel": static_x_data,
"label": y_data}, # "label": y_data},
fetch_list=fetch_list) # fetch_list=fetch_list)
static_param_value = {} # static_param_value = {}
static_out = out[0] # static_out = out[0]
for i in range(1, len(out)): # for i in range(1, len(out)):
static_param_value[static_param_name_list[i - 1]] = out[ # static_param_value[static_param_name_list[i - 1]] = out[
i] # i]
self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all())) # self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all()))
for key, value in six.iteritems(static_param_init_value): # for key, value in six.iteritems(static_param_init_value):
self.assertTrue(np.allclose(value, dy_param_init_value[key])) # self.assertTrue(np.allclose(value, dy_param_init_value[key]))
self.assertTrue(np.allclose(static_out, dy_out)) # self.assertTrue(np.allclose(static_out, dy_out))
for key, value in six.iteritems(static_param_value): # for key, value in six.iteritems(static_param_value):
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5)) # self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册