Commit 44b656a9 authored by: K kingfo

add data sync before hook function

Parent 9cbed69e
......@@ -602,6 +602,19 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
MS_LOG(DEBUG) << "End";
}
void FinalVM::SyncData(const py::object &arg) {
  if (py::isinstance<py::tuple>(arg)) {
    py::tuple arg_list = py::cast<py::tuple>(arg);
    for (size_t i = 0; i < arg_list.size(); i++) {
      SyncData(arg_list[i]);
    }
  }
  if (py::isinstance<tensor::Tensor>(arg)) {
    auto tensor = py::cast<tensor::TensorPtr>(arg);
    (void)tensor->data_sync();
  }
}
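For illustration, a minimal Python sketch of the same recursive sync, assuming `Tensor.asnumpy()` as the device-to-host trigger (the helper name `sync_data` is hypothetical; the C++ above is the actual implementation):

from mindspore import Tensor

def sync_data(arg):
    # Walk (possibly nested) tuples and force each tensor to pull
    # its data from device to host before the hook runs.
    if isinstance(arg, tuple):
        for item in arg:
            sync_data(item)
    if isinstance(arg, Tensor):
        arg.asnumpy()  # blocks until device data is available on the host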
BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
  MS_LOG(DEBUG) << "input for operation:";
  std::size_t args_size = args.size();
......@@ -612,15 +625,20 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
MS_LOG(DEBUG) << "arg: " << i << ":";
i++;
}
  // Hook operator for executing the cell custom bprop function
  py::object obj;
  bool is_bprop = prim->HasAttr("bprop");
  if (is_bprop) {
    SyncData(py_args);
    py::function fn_bprop = prim->hook();
    obj = fn_bprop(*py_args);
    return obj;
  }
  // Sync gradient data from device to host
  SyncData(py_args[2]);
  bool is_cell = prim->HasAttr("cell_hook");
  if (is_cell) {
    // Hook operator for executing the cell hook function
    std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
    if (_hook_grad.find(cell_id) != _hook_grad.end()) {
      std::size_t hook_args_size = 3;
......@@ -639,6 +657,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
      obj = py_args[2];
    }
  } else {
    // Hook operator for executing the variable hook function
    py::function fn_hook = prim->hook();
    obj = fn_hook(py::make_tuple(py_args[2]));
    if (py::isinstance<py::none>(obj)) {
......
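For context, the three branches of RunHook map to three registration paths in the Python API touched by this commit. A hedged sketch (cell and function names are illustrative, not part of the commit):

import mindspore.nn as nn
from mindspore.ops import operations as P

# 1. Custom bprop: the hook primitive carries the "bprop" attribute,
#    so all arguments are synced and the bprop function runs in Python.
class Square(nn.Cell):
    def construct(self, x):
        return x * x

    def bprop(self, x, out, dout):
        return (2 * x * dout,)

# 2. Cell hook: the primitive carries "cell_hook" and "cell_id".
def cell_hook_fn(cell_name, grad_input, grad_output):
    print(cell_name, grad_input, grad_output)

net = nn.ReLU()
net.register_backward_hook(cell_hook_fn)

# 3. Variable hook: a bare HookBackward primitive; only the gradient
#    (py_args[2] above) is synced and handed to the hook.
def var_hook_fn(grad):
    print(grad)

hook = P.HookBackward(var_hook_fn)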
......@@ -115,7 +115,7 @@ class FinalVM {
  void InstPushPrim(const VectorRef &args);
  void InstSwitchReturn(const VectorRef &args);
  void set_insts(const InstSet &value) { insts_ = value; }
  BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &args);
  BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg);

 protected:
  BaseRef Ref(int i);
......@@ -129,6 +129,7 @@ class FinalVM {
  void PushStatus(bool is_switch_call);
  bool PopStatus();
  void DoJmp(const BaseRef &jmp);
  void SyncData(const py::object &args);
  void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
  BaseRef MergeArgs(const BaseRef &first, const BaseRef &second);
......
......@@ -77,7 +77,7 @@ class Cell:
        if flags:
            self.add_flags(**flags)
        self._backward_hook = None
        self._enable_hook = False
        self.enable_hook = False
        self._bprop_debug = False

    @property
......@@ -97,10 +97,24 @@ class Cell:
    @property
    def bprop_debug(self):
        """
        Get whether cell custom bprop debug is enabled.
        """
        return self._bprop_debug

    @bprop_debug.setter
    def bprop_debug(self, value):
        """
        Set whether to enable cell custom bprop debug.

        Note:
            When bprop is defined in a cell, the bprop function will be executed
            in the Python interpreter when bprop debug is true, and will be parsed
            and added to the graph when bprop debug is false.

        Args:
            value (bool): Specifies whether to enable bprop debug. Default: False.
        """
        if not isinstance(value, bool):
            raise TypeError("'bprop_debug' value must be a bool.")
        self._bprop_debug = value
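A short usage sketch of the setter, assuming a cell that defines its own bprop as described in the Note:

import mindspore.nn as nn

class Net(nn.Cell):
    def construct(self, x):
        return x * x

    def bprop(self, x, out, dout):
        return (2 * x * dout,)

net = Net()
net.bprop_debug = True     # bprop runs in the Python interpreter (debuggable)
net.bprop_debug = False    # bprop is parsed and compiled into the graph
# net.bprop_debug = "yes"  # would raise TypeError: value must be a bool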
......@@ -755,17 +769,19 @@ class Cell:
        outputs = self._backward_hook(inputs)
        return outputs

    @property
    def enable_hook(self):
        """Whether the cell has registered a hook function."""
        return self._enable_hook
    def register_backward_hook(self, fn):
        """
        Set the cell backward hook function.

        Note:
            fn should be defined as the following code shows. `cell_name` is the name of the registered cell,
            `grad_input` is the gradient passed to the cell, and `grad_output` is the gradient computed and
            passed to the next cell or primitive, which may be modified and returned.

            >>> hook_fn(cell_name, grad_input, grad_output) -> Tensor or None

        Args:
            fn (function): Specifies the hook function with grad as input.
        """
        self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
        self._enable_hook = True
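A hedged usage sketch of the signature described in the Note (the cell choice is illustrative):

import mindspore.nn as nn

def hook_fn(cell_name, grad_input, grad_output):
    print("cell:", cell_name)
    print("grad_input:", grad_input)
    print("grad_output:", grad_output)
    return None  # return a Tensor here to replace grad_output

net = nn.ReLU()
net.register_backward_hook(hook_fn)  # fires when gradients flow through net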
......@@ -247,9 +247,11 @@ class HookBackward(PrimitiveWithInfer):
    Used as a tag to hook gradients of intermediate variables.

    Note:
        The hook function should have one input, the gradient of the variable.
        The hook function will be executed in the Python environment, while the callback
        of InsertGradientOf will be parsed and added to the graph.
        The hook function should be defined like `hook_fn(grad) -> Tensor or None`,
        where grad is the gradient passed to the primitive and may be
        modified and passed to the next primitive. The difference between the hook function
        and the callback of InsertGradientOf is that the hook function is executed in the
        Python environment while the callback will be parsed and added to the graph.

    Args:
        hook_fn (Function): Python hook function.
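An illustrative sketch of how the primitive is typically placed on an intermediate variable (the network shape is an assumption, not from this commit):

import mindspore.nn as nn
from mindspore.ops import operations as P

def hook_fn(grad):
    print(grad)   # inspect; return a Tensor instead to modify the gradient
    return None

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.hook = P.HookBackward(hook_fn)

    def construct(self, x):
        x = x * x
        x = self.hook(x)  # gradient flowing back through x triggers hook_fn
        return x + x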
......@@ -312,6 +314,8 @@ class Print(PrimitiveWithInfer):
        2. The data of the tensor is a scalar type.
        In PyNative mode, please use the Python print function.

    Inputs:
        - **input_x** (Union[Tensor, str]) - The graph node to attach to. The input supports
          multiple strings and tensors, which are separated by ','.
......
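A minimal sketch of the documented input convention, mixing strings and tensors separated by ',' (the network is illustrative):

import mindspore.nn as nn
from mindspore.ops import operations as P

class PrintNet(nn.Cell):
    def __init__(self):
        super(PrintNet, self).__init__()
        self.print = P.Print()

    def construct(self, x):
        self.print('x is:', x)  # strings and tensors separated by ','
        return x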