提交 a5e66e15 编写于 作者: K kingfo

change hook grad input to tuple

上级 e5c45bd3
...@@ -624,8 +624,8 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { ...@@ -624,8 +624,8 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
if (_hook_grad.find(cell_id) != _hook_grad.end()) { if (_hook_grad.find(cell_id) != _hook_grad.end()) {
py::tuple hook_args = py::tuple(3); py::tuple hook_args = py::tuple(3);
hook_args[0] = cell_id; hook_args[0] = cell_id;
hook_args[1] = _hook_grad[cell_id]; hook_args[1] = py::make_tuple(_hook_grad[cell_id]);
hook_args[2] = py_args[2]; hook_args[2] = py::make_tuple(py_args[2]);
py::function fn_hook = prim->hook(); py::function fn_hook = prim->hook();
obj = fn_hook(*hook_args); obj = fn_hook(*hook_args);
if (py::isinstance<py::none>(obj)) { if (py::isinstance<py::none>(obj)) {
...@@ -638,7 +638,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { ...@@ -638,7 +638,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
} }
} else { } else {
py::function fn_hook = prim->hook(); py::function fn_hook = prim->hook();
obj = fn_hook(py_args[2]); obj = fn_hook(py::make_tuple(py_args[2]));
if (py::isinstance<py::none>(obj)) { if (py::isinstance<py::none>(obj)) {
obj = py_args[2]; obj = py_args[2];
} }
......
...@@ -30,13 +30,13 @@ def weight_variable(): ...@@ -30,13 +30,13 @@ def weight_variable():
def cell_hook_function(cell_id, grad_input, grad_output): def cell_hook_function(cell_id, grad_input, grad_output):
print(cell_id) print(cell_id)
assert(grad_output.asnumpy().shape == (32, 6, 14, 14)) assert(grad_output[0].asnumpy().shape == (32, 6, 14, 14))
assert(grad_input.asnumpy().shape == (32, 16, 10, 10)) assert(grad_input[0].asnumpy().shape == (32, 16, 10, 10))
def var_hook_function(grad_out): def var_hook_function(grad_out):
print("grad:", grad_out) print("grad:", grad_out)
assert(grad_out.asnumpy().shape == (32, 120)) assert(grad_out[0].asnumpy().shape == (32, 120))
class LeNet5(nn.Cell): class LeNet5(nn.Cell):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册