提交 2120d1ea 编写于 作者: K kingfo

add InsertGradientOf operator support in pynative mode

上级 6a5c00ff
......@@ -58,7 +58,8 @@ using mindspore::tensor::TensorPy;
const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"};
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
"mixed_precision_cast"};
namespace mindspore {
namespace pynative {
......@@ -346,7 +347,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
auto &op_inputs = op_exec_info->op_inputs;
if (op_exec_info->op_name == "HookBackward") {
if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf") {
py::tuple result(op_inputs.size());
for (size_t i = 0; i < op_inputs.size(); i++) {
py::object input = op_inputs[i];
......
......@@ -238,10 +238,6 @@ class InsertGradientOf(PrimitiveWithInfer):
def __init__(self, f):
self.f = f
def __call__(self, x):
"""run in PyNative mode."""
return x
def infer_shape(self, x_shape):
return x_shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册