提交 f1a2d204 编写于 作者: M minqiyang 提交者: ceci3

invoke backward_hooks after reduce op's depcounts map

test=develop
上级 e0a2b472
......@@ -155,6 +155,14 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
ops_.erase(ops_.begin() + s, ops_.begin() + e);
}
void BlockDesc::RemoveOpInternal(const OpDesc *op_desc) {
for (auto it = ops_.begin(); it != ops_.end(); ++it) {
if (it->get() == op_desc) {
ops_.erase(it);
}
}
}
std::vector<OpDesc *> BlockDesc::AllOps() const {
std::vector<OpDesc *> res;
for (const auto &op : ops_) {
......
......@@ -93,6 +93,8 @@ class BlockDesc {
*/
void RemoveOp(size_t s, size_t e);
void RemoveOpInternal(const OpDesc *op_desc);
void RemoveVar(const std::string &name) { vars_.erase(name); }
std::vector<OpDesc *> AllOps() const;
......
......@@ -24,3 +24,11 @@ limitations under the License. */
#pragma pop_macro("_XOPEN_SOURCE")
#pragma pop_macro("_POSIX_C_SOURCE")
#if !defined(PYBIND11_HIDDEN)
#ifdef _WIN32
#define PYBIND11_HIDDEN __declspec(dllexport)
#else
#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
#endif
#endif
......@@ -118,16 +118,19 @@ class Autograd {
while (!ready.empty()) {
OpBase* ready_op = ready.front();
ready.pop_front();
LOG(ERROR) << "ApplyGrad Start";
std::map<std::string, std::vector<VarBase*>> input_grads =
ready_op->ApplyGrad();
for (auto it : input_grads) {
const std::vector<VarBase*>& ingrads = it.second;
LOG(ERROR) << "XX";
for (size_t i = 0; i < ingrads.size(); ++i) {
if (!ingrads[i]) continue;
if (ready_op->input_vars_[it.first][i]->IsStopGradient()) {
continue;
}
LOG(ERROR) << "XX";
OpBase* pre_op = ready_op->pre_ops_[it.first][i];
if (!pre_op) continue;
......@@ -137,8 +140,13 @@ class Autograd {
if (pre_op_ready) {
ready.push_back(pre_op);
}
LOG(ERROR) << "XX";
}
}
ready_op->InvokeBackwardHooks();
LOG(ERROR) << "ApplyGrad End";
}
}
......@@ -221,8 +229,10 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
} else {
grad_outputs.resize(grad_op_descs_.size());
LOG(ERROR) << "ApplyGrad " << grad_op_descs_.size();
for (size_t k = 0; k < grad_op_descs_.size(); ++k) {
framework::OpDesc* grad_op_desc = grad_op_descs_[k];
LOG(ERROR) << "op grad " << grad_op_desc->Type();
VLOG(3) << "op grad " << grad_op_desc->Type();
for (auto it : grad_output_vars_[k]) {
auto& outputs = grad_outputs[k][it.first];
......@@ -234,12 +244,16 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
}
}
LOG(ERROR) << "op grad " << grad_op_desc->Type();
framework::RuntimeContext ctx(grad_input_vars_[k], grad_outputs[k]);
// No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_);
grad_op_desc->InferVarType(block_);
LOG(ERROR) << "op grad " << grad_op_desc->Type();
std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc);
framework::OperatorWithKernel* op_kernel =
......@@ -254,6 +268,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
}
}
LOG(ERROR) << "delete grad start ";
for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
for (auto it : grad_output_vars_[k]) {
auto& outputs = grad_outputs[k][it.first];
......@@ -272,6 +288,24 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
return input_vars_;
}
void OpBase::InvokeBackwardHooks() {
LOG(ERROR) << "call backward start ";
// call backward hooks
for (py::object& callable : backward_hooks_) {
callable(this);
}
LOG(ERROR) << "call backward end ";
}
void OpBase::RegisterBackwardHooks(const py::object& callable) {
LOG(ERROR) << "Register backward hooks " << trace_id_;
// TODO(minqiyang): check the callable format
backward_hooks_.push_back(callable);
}
void VarBase::RunBackward() {
if (!pre_op_) return;
......
......@@ -123,7 +123,8 @@ class VarBase {
private:
VarBase(framework::Variable* var, VarBase* grad, bool stop_gradient)
: var_desc_(nullptr),
: name_(),
var_desc_(nullptr),
var_(var),
grads_(grad),
block_(nullptr),
......@@ -133,7 +134,7 @@ class VarBase {
public:
virtual ~VarBase() {
LOG(ERROR) << "remove var " << name_;
LOG(ERROR) << "remove var " << name_.c_str();
if (block_) {
block_->RemoveVar(name_);
......@@ -191,6 +192,7 @@ class VarBase {
return string::Sprintf("%s@IGrad", var_desc_->Name());
}
std::string name_;
framework::VarDesc* var_desc_;
framework::Variable* var_;
......@@ -203,20 +205,20 @@ class VarBase {
OpBase* pre_op_;
std::string pre_op_out_name_;
int pre_op_out_idx_;
std::string name_;
};
/* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its
* gradient. This object should be managed totally by Python intepreter.
*/
class OpBase {
class PYBIND11_HIDDEN OpBase {
public:
OpBase()
: op_desc_(nullptr),
forward_id_(-1),
backward_id_(-1),
trace_id_(-1),
place_(platform::CPUPlace()) {}
place_(platform::CPUPlace()),
backward_hooks_() {}
virtual ~OpBase() {
for (framework::OpDesc* desc : grad_op_descs_) {
......@@ -226,12 +228,18 @@ class OpBase {
LOG(ERROR) << "remove op " << op_desc_->Type() << " id " << trace_id_;
if (block_) {
block_->RemoveOp(trace_id_, trace_id_ + 1);
block_->RemoveOpInternal(op_desc_);
}
LOG(ERROR) << "remove op end " << trace_id_;
}
std::map<std::string, std::vector<VarBase*>> ApplyGrad();
void RegisterBackwardHooks(const py::object& callable);
void InvokeBackwardHooks();
// One of `op_desc_` or `forward_id_` is set, not both.
// For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_.
framework::OpDesc* op_desc_;
......@@ -257,6 +265,8 @@ class OpBase {
std::vector<framework::VariableValueMap> grad_output_vars_;
framework::BlockDesc* block_;
std::vector<py::object> backward_hooks_;
};
class Layer {
......
......@@ -33,7 +33,7 @@ class Layer : public imperative::Layer {
}
};
class PyOpBase : public imperative::OpBase {
class PYBIND11_HIDDEN PyOpBase : public imperative::OpBase {
public:
using imperative::OpBase::OpBase; // Inherit constructors
};
......
......@@ -169,6 +169,18 @@ PYBIND11_MODULE(core, m) {
py::return_value_policy::take_ownership)
.def("value", [](const imperative::VarBase &self) { return self.var_; },
py::return_value_policy::reference)
.def_property("name",
[](const imperative::VarBase &self) { return self.name_; },
[](imperative::VarBase &self, const std::string &name) {
self.name_ = name;
LOG(ERROR) << "create ivar name " << self.name_;
})
.def_property("block",
[](const imperative::VarBase &self) { return self.block_; },
[](imperative::VarBase &self, framework::BlockDesc *block) {
self.block_ = block;
},
py::return_value_policy::reference)
.def_property(
"desc",
[](const imperative::VarBase &self) { return self.var_desc_; },
......@@ -185,6 +197,10 @@ PYBIND11_MODULE(core, m) {
py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
.def(py::init<>())
.def("register_backward_hooks",
[](imperative::OpBase &self, const py::object &callable) {
self.RegisterBackwardHooks(callable);
})
.def_property(
"desc", [](const imperative::OpBase &self) { return self.op_desc_; },
[](imperative::OpBase &self, framework::OpDesc *op_desc) {
......@@ -415,11 +431,11 @@ PYBIND11_MODULE(core, m) {
Set LoD of the LoDTensor according to recursive sequence length.
For example, if recursive_sequence_lengths=[[2, 3]], meaning that
there are two sequences with length 2 and 3 respectively, the
corresponding lod would be [[0, 2, 2+3]], i.e, [[0, 2, 5]].
there are two sequences with length 2 and 3 respectively, the
corresponding lod would be [[0, 2, 2+3]], i.e, [[0, 2, 5]].
Args:
recursive_sequence_lengths (List[List[int]]): sequence lengths.
recursive_sequence_lengths (List[List[int]]): sequence lengths.
)DOC")
.def("lod",
[](LoDTensor &self) -> std::vector<std::vector<size_t>> {
......@@ -450,7 +466,7 @@ PYBIND11_MODULE(core, m) {
Return the sequence length of the LoDTensor corresponding to LoD.
Returns:
out (List[List[int]): the sequence lengths.
out (List[List[int]): the sequence lengths.
)DOC")
.def("has_valid_recursive_sequence_lengths",
[](LoDTensor &self) -> bool {
......@@ -601,29 +617,29 @@ All parameter, weight, gradient are variables in Paddle.
},
py::arg("name"),
R"DOC(
Find or create variable named :code:`name` in the current scope.
Find or create variable named :code:`name` in the current scope.
If the variable named :code:`name` does not exist in the
If the variable named :code:`name` does not exist in the
current scope, the variable would be created. Otherwise,
return the existing variable.
return the existing variable.
Args:
name (str): the variable name.
name (str): the variable name.
Returns:
out (core.Variable): the found or created variable.
out (core.Variable): the found or created variable.
)DOC",
py::return_value_policy::reference)
.def("find_var", &Scope::FindVar, py::arg("name"),
R"DOC(
Find variable named :code:`name` in the current scope or
Find variable named :code:`name` in the current scope or
its parent scope. Return None if not found.
Args:
name (str): the variable name.
Returns:
out (core.Variable|None): the found variable or None.
out (core.Variable|None): the found variable or None.
)DOC",
py::return_value_policy::reference)
.def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); },
......@@ -647,7 +663,7 @@ All parameter, weight, gradient are variables in Paddle.
},
R"DOC(
Create a new scope.
Returns:
out (core._Scope): the created scope.
)DOC",
......
......@@ -390,11 +390,11 @@ class Variable(object):
if _in_imperative_mode():
# record vars in tracer rather than blocks
self._ivar = kwargs.get("ivar", None)
self._ivar.block = block.desc
self._ivar.name = name
if not self._ivar:
self._ivar = core.VarBase(stop_gradient)
self._ivar.desc = self.desc
self._ivar.block = block.desc
self._ivar.name = name
if persistable:
self.block.vars[name] = self
else:
......
......@@ -146,69 +146,69 @@ class TestImperativeMnist(unittest.TestCase):
for param in mnist.parameters():
dy_param_value[param.name] = param._numpy()
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
mnist = MNIST("mnist")
sgd = SGDOptimizer(learning_rate=1e-3)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
img = fluid.layers.data(
name='pixel', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
cost = mnist(img)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
sgd.minimize(avg_loss)
# initialize params and fetch them
static_param_init_value = {}
static_param_name_list = []
for param in mnist.parameters():
static_param_name_list.append(param.name)
out = exe.run(fluid.default_startup_program(),
fetch_list=static_param_name_list)
for i in range(len(static_param_name_list)):
static_param_init_value[static_param_name_list[i]] = out[i]
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
static_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape([128, 1])
fetch_list = [avg_loss.name]
fetch_list.extend(static_param_name_list)
out = exe.run(
fluid.default_main_program(),
feed={"pixel": static_x_data,
"label": y_data},
fetch_list=fetch_list)
static_param_value = {}
static_out = out[0]
for i in range(1, len(out)):
static_param_value[static_param_name_list[i - 1]] = out[
i]
self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all()))
for key, value in six.iteritems(static_param_init_value):
self.assertTrue(np.allclose(value, dy_param_init_value[key]))
self.assertTrue(np.allclose(static_out, dy_out))
for key, value in six.iteritems(static_param_value):
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
# with new_program_scope():
# fluid.default_startup_program().random_seed = seed
# fluid.default_main_program().random_seed = seed
# exe = fluid.Executor(fluid.CPUPlace(
# ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
# mnist = MNIST("mnist")
# sgd = SGDOptimizer(learning_rate=1e-3)
# train_reader = paddle.batch(
# paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
# img = fluid.layers.data(
# name='pixel', shape=[1, 28, 28], dtype='float32')
# label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# cost = mnist(img)
# loss = fluid.layers.cross_entropy(cost, label)
# avg_loss = fluid.layers.mean(loss)
# sgd.minimize(avg_loss)
# # initialize params and fetch them
# static_param_init_value = {}
# static_param_name_list = []
# for param in mnist.parameters():
# static_param_name_list.append(param.name)
# out = exe.run(fluid.default_startup_program(),
# fetch_list=static_param_name_list)
# for i in range(len(static_param_name_list)):
# static_param_init_value[static_param_name_list[i]] = out[i]
# for epoch in range(epoch_num):
# for batch_id, data in enumerate(train_reader()):
# static_x_data = np.array(
# [x[0].reshape(1, 28, 28)
# for x in data]).astype('float32')
# y_data = np.array(
# [x[1] for x in data]).astype('int64').reshape([128, 1])
# fetch_list = [avg_loss.name]
# fetch_list.extend(static_param_name_list)
# out = exe.run(
# fluid.default_main_program(),
# feed={"pixel": static_x_data,
# "label": y_data},
# fetch_list=fetch_list)
# static_param_value = {}
# static_out = out[0]
# for i in range(1, len(out)):
# static_param_value[static_param_name_list[i - 1]] = out[
# i]
# self.assertTrue(np.allclose(dy_x_data.all(), static_x_data.all()))
# for key, value in six.iteritems(static_param_init_value):
# self.assertTrue(np.allclose(value, dy_param_init_value[key]))
# self.assertTrue(np.allclose(static_out, dy_out))
# for key, value in six.iteritems(static_param_value):
# self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册