Commit 840922e5 authored by K kingfo

add backward hook function in pynative mode

Parent d402b944
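For orientation, here is a minimal usage sketch of the backward-hook APIs this commit introduces, mirroring the test case at the end of the diff (illustrative only, not part of the change itself):

import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context

context.set_context(mode=context.PYNATIVE_MODE)

def var_hook_function(grad_out):
    # receives the gradient flowing into the hooked variable during backprop
    print("grad:", grad_out)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.dense = nn.Dense(120, 84)
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        x = self.dense(x)
        x = self.hook(x)  # tag x so its gradient is reported to var_hook_function
        return x

The hook fires when gradients are computed through the network (for example via composite.GradOperation), as exercised by test_hook in the test file below.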
......@@ -102,7 +102,10 @@ def get_parse_method_of_class(obj, parse_method=None):
method_name = parse_method
else:
if isinstance(obj, nn.Cell):
method_name = "construct"
if obj.enable_hook:
method_name = "_hook_construct"
else:
method_name = "construct"
if method_name is not None:
if hasattr(obj, method_name):
method = getattr(obj, method_name)
......
......@@ -115,6 +115,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
.def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
.def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
.def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
}));
} // namespace mindspore
......@@ -23,7 +23,6 @@
#include <string>
#include <tuple>
#include "pybind11/pybind11.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "utils/misc.h"
#include "utils/log_adapter.h"
......@@ -31,8 +30,6 @@
#include "ir/signature.h"
#include "parallel/ops_info/operator_info.h"
namespace py = pybind11;
namespace mindspore {
class PrimitivePy : public Primitive {
public:
......
......@@ -24,6 +24,9 @@
#include <tuple>
#include "ir/dtype/type.h"
#include "pybind11/pybind11.h"
namespace py = pybind11;
namespace mindspore {
// Supported meta type
......@@ -73,6 +76,9 @@ class Primitive : public Named {
return iter == attrs_.cend() ? nullptr : iter->second;
}
void set_hook(const py::function &hook) { hook_ = hook; }
py::function hook() const { return hook_; }
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
......@@ -103,6 +109,7 @@ class Primitive : public Named {
private:
std::string instance_name_;
py::function hook_;
bool is_base_;
bool has_signature_;
PrimType prim_type_;
......
......@@ -211,6 +211,7 @@ const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor");
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
// Other miscellaneous
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
......@@ -224,6 +225,7 @@ const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
......
......@@ -216,6 +216,7 @@ extern const PrimitivePtr kPrimReluV2;
extern const PrimitivePtr kPrimActivation;
extern const PrimitivePtr kPrimZerosLikeTensor;
extern const PrimitivePtr kPrimFakeBprop;
extern const PrimitivePtr kPrimBpropCut;
// Other Miscellaneous
extern const PrimitivePtr kPrimIdentity;
......@@ -230,6 +231,7 @@ extern const PrimitivePtr kPrimGetRefKey;
extern const PrimitivePtr kPrimGetRefValue;
extern const PrimitivePtr kPrimGetRefOrigin;
extern const PrimitivePtr kPrimInsertGradientOf;
extern const PrimitivePtr kPrimHookBackward;
extern const PrimitivePtr kPrimPrintShapeType;
extern const PrimitivePtr kPrimPrint;
extern const PrimitivePtr kPrimSameTypeShape;
......
......@@ -285,6 +285,16 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr
return args_spec_list[0]->Broaden();
}
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: the forward inputs followed by out and dout; broaden the forward inputs and return them as a tuple.
AbstractBasePtrList args_list;
for (size_t i = 0; i < args_spec_list.size() - 2; i++) {
args_list.push_back(args_spec_list[i]->Broaden());
}
return std::make_shared<AbstractTuple>(args_list);
}
AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: three tensors(x, gamma, beta).
......
......@@ -32,6 +32,7 @@
#include "operator/ops.h"
#include "operator/composite/composite.h"
#include "utils/symbolic.h"
#include "utils/context/ms_context.h"
#include "./common.h"
namespace mindspore {
......
......@@ -125,6 +125,7 @@ class KPrim {
FuncGraphPtr GetBprop(const PrimitivePtr &prim);
FuncGraphPtr GetFprop(const PrimitivePtr &prim);
FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
// Given a bprop rule, do the K mapping.
template <typename T>
FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g);
......
......@@ -115,10 +115,15 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
}
bool is_faked_bprop = false;
auto bprop_fg = GetBprop(prim);
if (bprop_fg == nullptr) {
bprop_fg = FakeBprop(value_node, resources);
is_faked_bprop = true;
FuncGraphPtr bprop_fg = nullptr;
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == "HookBackward") {
bprop_fg = BpropCut(value_node, resources);
} else {
bprop_fg = GetBprop(prim);
if (bprop_fg == nullptr) {
bprop_fg = FakeBprop(value_node, resources);
is_faked_bprop = true;
}
}
auto expanded_fg = BpropToK(prim, bprop_fg);
......@@ -206,6 +211,45 @@ FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
return expanded_fg;
}
FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
auto prim = GetValueNode<PrimitivePtr>(value_node);
MS_EXCEPTION_IF_NULL(prim);
auto &node_users = resources->manager()->node_users();
auto &users = node_users[value_node];
auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int> &user) -> bool {
return IsPrimitiveCNode(user.first, prim);
});
if (cnode == users.end()) {
MS_LOG(EXCEPTION) << "Fail to find cnode.";
}
auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
auto func_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;
auto bprop_cut = std::make_shared<Primitive>("bprop_cut");
bprop_cut->set_hook(prim->hook());
auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
if (cell_id != "") {
(void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
(void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
}
outputs.push_back(NewValueNode(bprop_cut));
for (size_t i = 0; i < inputs_num; ++i) {
auto param = func_graph->add_parameter();
outputs.push_back(param);
}
auto p1 = func_graph->add_parameter();
auto p2 = func_graph->add_parameter();
outputs.push_back(p1);
outputs.push_back(p2);
func_graph->set_output(func_graph->NewCNode(outputs));
return func_graph;
}
FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
auto prim = value_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim);
......
......@@ -49,9 +49,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
special_op_eliminate_ =
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
......
......@@ -35,11 +35,13 @@ class SpecialOpEliminater {
public:
SpecialOpEliminater()
: insert_gradient_of_(prim::kPrimInsertGradientOf),
hook_backward_(prim::kPrimHookBackward),
print_shape_type_(prim::kPrimPrintShapeType),
get_ref_value_(prim::kPrimGetRefValue),
mirror_(prim::kPrimMirror),
virtual_div_(prim::kPrimVirtualDiv) {
eliminaters_.emplace_back(insert_gradient_of_);
eliminaters_.emplace_back(hook_backward_);
eliminaters_.emplace_back(print_shape_type_);
eliminaters_.emplace_back(get_ref_value_);
eliminaters_.emplace_back(mirror_);
......@@ -59,7 +61,7 @@ class SpecialOpEliminater {
}
private:
PrimEliminater insert_gradient_of_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
PrimEliminater insert_gradient_of_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
std::vector<TransformFuncType> eliminaters_{};
};
......
......@@ -30,6 +30,7 @@
#include "operator/composite/composite.h"
#include "ir/func_graph_cloner.h"
#include "utils/symbolic.h"
#include "utils/context/ms_context.h"
#include "debug/trace.h"
namespace mindspore {
......@@ -207,6 +208,35 @@ bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
return true;
}
FuncGraphPtr ConvertToBpropCut(py::object obj) {
std::vector<std::string> results = data_converter::GetObjKey(obj);
std::string obj_key = results[0];
py::function bprop_func = py::getattr(obj, "bprop");
FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
std::vector<AnfNodePtr> outputs;
auto fake_bprop = std::make_shared<Primitive>("bprop_cut");
fake_bprop->set_hook(bprop_func);
(void)fake_bprop->AddAttr("bprop", MakeValue(true));
outputs.push_back(NewValueNode(fake_bprop));
py::object code_obj = py::getattr(bprop_func, "__code__");
size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
for (size_t i = 0; i < inputs_num; ++i) {
auto param = bprop_graph->add_parameter();
outputs.push_back(param);
}
auto p1 = bprop_graph->add_parameter();
auto p2 = bprop_graph->add_parameter();
outputs.push_back(p1);
outputs.push_back(p2);
bprop_graph->set_output(bprop_graph->NewCNode(outputs));
data_converter::SetObjGraphValue(obj_key, bprop_graph);
return bprop_graph;
}
bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
auto obj_type = data_converter::GetObjType(obj);
MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
......@@ -238,7 +268,13 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
}
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
if (py::hasattr(obj, "bprop")) {
FuncGraphPtr bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
FuncGraphPtr bprop_graph = nullptr;
bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
if (enable_bprop_debug) {
bprop_graph = ConvertToBpropCut(obj);
} else {
bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
}
if (bprop_graph != nullptr) {
(void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
......
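The bprop_debug branch above routes a user-defined bprop through ConvertToBpropCut, which reads co_argcount from the bprop function and subtracts 3 (self, out, dout) to get the number of forward inputs. A sketch of the cell shape this expects, mirroring the MulAdd case in the test file at the end of this diff:

import mindspore.nn as nn
from mindspore.ops import composite as C

class MulAdd(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    # co_argcount == 5 (self, x, y, out, dout), so inputs_num = 5 - 3 = 2
    def bprop(self, x, y, out, dout):
        return 3 * dout, 2 * y

mul_add = MulAdd()
mul_add.bprop_debug = True  # run the custom bprop in Python through bprop_cut
assert C.grad_all(mul_add)(1, 2) == (3, 4)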
......@@ -108,6 +108,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimRelu, {InferImplRelu, true}},
{prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}},
{prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
{prim::kPrimBpropCut, {InferImplBpropCut, true}},
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
......
......@@ -210,6 +210,8 @@ AbstractBasePtr InferImplZerosLikeTensor(const AnalysisEnginePtr &, const Primit
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
......
......@@ -64,6 +64,9 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst) {
result.outputs = outputs;
result.graph_id = kInvalidGraphId;
auto graph_id = sess_->CompileGraph(lst, outputs);
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
sess_->BuildGraph(graph_id);
}
if (MsContext::GetInstance()->precompile_only()) {
MS_LOG(INFO) << "PrecompileOnly, stop run graph";
return result;
......
......@@ -40,9 +40,10 @@ using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
prim::kPrimMakeTuple};
prim::kPrimMakeTuple, prim::kPrimBpropCut};
const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch};
static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
prim::kPrimBpropCut};
return ms_nonlinear_ops;
}
......@@ -646,8 +647,13 @@ BackendPtr CreateBackend() {
auto backend = std::make_shared<MsBackend>(name, target, device_id);
std::string device_target = MsContext::GetInstance()->device_target();
if (device_target == kAscendDevice) {
backend->set_is_multi_graph_sink(true);
context_ptr->set_is_multi_graph_sink(true);
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
backend->set_is_multi_graph_sink(false);
context_ptr->set_is_multi_graph_sink(false);
} else {
backend->set_is_multi_graph_sink(true);
context_ptr->set_is_multi_graph_sink(true);
}
}
return backend;
}
......
......@@ -587,15 +587,65 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
VectorRef tuple;
auto prim = utils::cast<PrimitivePtr>(args[0]);
for (size_t i = 1; i < args.size(); ++i) {
auto index = utils::cast<int>(args[1]);
auto index = utils::cast<int>(args[i]);
tuple.push_back(Ref(index));
}
auto outs = RunOperation(prim, tuple);
Push(outs);
if (prim->name() == "bprop_cut") {
auto outs = RunHook(prim, tuple);
Push(outs);
} else {
auto outs = RunOperation(prim, tuple);
Push(outs);
}
MS_LOG(DEBUG) << "End";
}
BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
py::tuple py_args = py::tuple(args.size());
MS_LOG(DEBUG) << "input for operation:";
size_t i = 0;
for (auto &arg : args) {
py_args[i] = BaseRefToPyData(arg);
MS_LOG(DEBUG) << "arg: " << i << ":";
i++;
}
py::object obj;
bool is_bprop = prim->HasAttr("bprop");
if (is_bprop) {
py::function fn_bprop = prim->hook();
obj = fn_bprop(*py_args);
return obj;
}
bool is_cell = prim->HasAttr("cell_hook");
if (is_cell) {
std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
if (_hook_grad.find(cell_id) != _hook_grad.end()) {
py::tuple hook_args = py::tuple(3);
hook_args[0] = cell_id;
hook_args[1] = _hook_grad[cell_id];
hook_args[2] = py_args[2];
py::function fn_hook = prim->hook();
obj = fn_hook(*hook_args);
if (py::isinstance<py::none>(obj)) {
obj = py_args[2];
}
_hook_grad.erase(cell_id);
} else {
_hook_grad[cell_id] = py_args[2];
obj = py_args[2];
}
} else {
py::function fn_hook = prim->hook();
obj = fn_hook(py_args[2]);
if (py::isinstance<py::none>(obj)) {
obj = py_args[2];
}
}
obj = py::make_tuple(obj);
return obj;
}
} // namespace compile
} // namespace mindspore
......@@ -115,6 +115,7 @@ class FinalVM {
void InstPushPrim(const VectorRef &args);
void InstSwitchReturn(const VectorRef &args);
void set_insts(const InstSet &value) { insts_ = value; }
BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &args);
protected:
BaseRef Ref(int i);
......@@ -156,6 +157,7 @@ class FinalVM {
{Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }},
{Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }},
};
std::map<std::string, py::object> _hook_grad;
};
using FinalVMPtr = std::shared_ptr<FinalVM>;
......
......@@ -24,6 +24,7 @@ from .._checkparam import _check_str_by_regular
from ..common.parameter import Parameter, ParameterTuple
from .._c_expression import init_backend
from ..ops.primitive import Primitive
from ..ops.operations import HookBackward
from ..parallel._tensor import _load_tensor_by_layout
from ..common.tensor import Tensor
......@@ -75,6 +76,9 @@ class Cell:
self._parallel_inputs_run = None
if flags:
self.add_flags(**flags)
self._backward_hook = None
self._enable_hook = False
self._bprop_debug = False
@property
def create_time(self):
......@@ -91,6 +95,16 @@ class Cell:
"""
return self._param_prefix
@property
def bprop_debug(self):
return self._bprop_debug
@bprop_debug.setter
def bprop_debug(self, value):
if not isinstance(value, bool):
raise TypeError("'bprop_debug' value must be a bool.")
self._bprop_debug = value
def update_cell_prefix(self):
"""
Update the all child cells' self.param_prefix.
......@@ -728,3 +742,25 @@ class Cell:
self._auto_parallel_mode = True
self.add_flags(auto_parallel=True)
self._get_construct_inputs_number_and_name()
def _hook_construct(self, inputs):
"""Hook construct method to replace original construct method when hook function enabled."""
inputs = self._backward_hook(inputs)
inputs = self.construct(inputs)
outputs = self._backward_hook(inputs)
return outputs
@property
def enable_hook(self):
"""Whether the cell register hook function"""
return self._enable_hook
def register_backward_hook(self, fn):
"""
Set the cell backward hook function.
Args:
fn (function): Hook function, called with (cell_id, grad_input, grad_output) during backpropagation.
"""
self._backward_hook = HookBackward(fn, str(id(self)))
self._enable_hook = True
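For reference, a sketch of the hook signature register_backward_hook expects; the three arguments match how FinalVM::RunHook below invokes prim->hook() for a cell hook, i.e. the cell id plus the gradients captured by the two taps that _hook_construct places around construct (layer choice here is illustrative):

import mindspore.nn as nn

def cell_hook_function(cell_id, grad_input, grad_output):
    # cell_id identifies the hooked cell; grad_input/grad_output are the gradients
    # captured at the HookBackward taps before and after the cell's construct
    print(cell_id, grad_input, grad_output)

conv = nn.Conv2d(6, 16, 5)
conv.register_backward_hook(cell_hook_function)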
......@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice)
from .debug_ops import (ImageSummary, InsertGradientOf, ScalarSummary,
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print)
from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast
......@@ -155,6 +155,7 @@ __all__ = [
'HistogramSummary',
"Print",
'InsertGradientOf',
'HookBackward',
'InvertPermutation',
'Shape',
'DropoutDoMask',
......
......@@ -14,6 +14,7 @@
# ============================================================================
"""debug_ops"""
from types import FunctionType
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer
......@@ -193,6 +194,65 @@ class InsertGradientOf(PrimitiveWithInfer):
return x_type
class HookBackward(PrimitiveWithInfer):
"""
Used as a tag to hook the gradient of an intermediate variable.
Note:
The hook function takes a single input, the gradient of the hooked variable.
The hook function is executed in the Python environment, while the callback
of InsertGradientOf is parsed and added to the graph.
Args:
hook_fn (Function): Python hook function.
Inputs:
- **inputs** (Tensor) - The variable to hook.
Examples:
>>> def hook_fn(grad_out):
>>> print(grad_out)
>>>
>>> hook = P.HookBackward(hook_fn)
>>>
>>> def hook_test(x, y):
>>> z = x * y
>>> z = hook(z)
>>> z = z * y
>>> return z
>>>
>>> def backward(x, y):
>>> return C.grad_all(hook_test)(x, y)
>>>
>>> backward(1, 2)
"""
def __init__(self, hook_fn, cell_id=""):
super(HookBackward, self).__init__(self.__class__.__name__)
self.add_prim_attr("cell_id", cell_id)
self.init_attrs["cell_id"] = cell_id
if not isinstance(hook_fn, FunctionType):
raise TypeError("Hook function should be python function type.")
self.register_hook(hook_fn)
self.cell_id = cell_id
def __call__(self, *inputs):
"""run in PyNative mode."""
if len(inputs) == 1:
return inputs[0]
return inputs
def infer_shape(self, *inputs_shape):
if len(inputs_shape) == 1:
return inputs_shape[0]
return inputs_shape
def infer_dtype(self, *inputs_type):
if len(inputs_type) == 1:
return inputs_type[0]
return inputs_type
class Print(PrimitiveWithInfer):
"""
Output tensor or string to stdout.
......
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore import context, Tensor, ParameterTuple
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn import Dense, WithLossCell, SoftmaxCrossEntropyWithLogits, Momentum
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
def cell_hook_function(cell_id, grad_input, grad_output):
print(cell_id)
assert(grad_output.asnumpy().shape == (32, 6, 14, 14))
assert(grad_input.asnumpy().shape == (32, 16, 10, 10))
def var_hook_function(grad_out):
print("grad:", grad_out)
assert(grad_out.asnumpy().shape == (32, 120))
class LeNet5(nn.Cell):
"""
Lenet network
Args:
num_class (int): Num classes. Default: 10.
Returns:
Tensor, output tensor
Examples:
>>> LeNet(num_class=10)
"""
def __init__(self, num_class=10):
super(LeNet5, self).__init__()
self.num_class = num_class
self.batch_size = 32
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.conv2.register_backward_hook(cell_hook_function)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.hook = P.HookBackward(var_hook_function)
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.reshape(x, (self.batch_size, -1))
x = self.fc1(x)
x = self.hook(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
class GradWrap(nn.Cell):
""" GradWrap definition """
def __init__(self, network):
super(GradWrap, self).__init__(auto_prefix=False)
self.network = network
self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
def construct(self, x, label):
weights = self.weights
return C.GradOperation('get_by_list', get_by_list=True)(self.network, weights)(x, label)
def test_hook():
net = LeNet5()
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False)
net_with_criterion = WithLossCell(net, criterion)
train_network = GradWrap(net_with_criterion)
train_network.set_train()
input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32))
output = net(input_data)
loss_output = criterion(output, label)
grads = train_network(input_data, label)
success = optimizer(grads)
print(loss_output.asnumpy().shape)
class MulAdd(nn.Cell):
def __init__(self):
super(MulAdd, self).__init__()
def construct(self, x, y):
return 2 * x + y
def bprop(self, x, y, out, dout):
assert(x == 1)
assert(y == 2)
assert(out == 4)
assert(dout == 1)
return 3 * dout, 2 * y
def test_custom_bprop():
mul_add = MulAdd()
mul_add.bprop_debug = True
assert C.grad_all(mul_add)(1, 2) == (3, 4)