use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
Default: True
Default: True
Returns:
Returns:
None
missing_keys(list):A list of str containing the missing keys
unexpected_keys(list):A list of str containing the unexpected keys
Examples:
Examples:
.. code-block:: python
.. code-block:: python
...
@@ -1615,15 +1616,20 @@ class Layer:
...
@@ -1615,15 +1616,20 @@ class Layer:
emb.set_state_dict(para_state_dict)
emb.set_state_dict(para_state_dict)
'''
'''
missing_keys=[]
match_keys=set()
unexpected_keys=[]
def_check_match(key,param):
def_check_match(key,param):
state=state_dict.get(key,None)
state=state_dict.get(key,None)
ifstateisNone:
ifstateisNone:
missing_keys.append(key)
raiseValueError(
raiseValueError(
"{} is not found in the provided dict.".format(key)
"{} is not found in the provided dict.".format(key)
)
)
ifisinstance(state,dict)orisinstance(state,list):
ifisinstance(state,dict)orisinstance(state,list):
iflen(state)!=len(param):
iflen(state)!=len(param):
missing_keys.append(key)
raiseValueError(
raiseValueError(
"{} receieves the length of {}, "
"{} receieves the length of {}, "
"but the expected shape is {}".format(
"but the expected shape is {}".format(
...
@@ -1631,6 +1637,7 @@ class Layer:
...
@@ -1631,6 +1637,7 @@ class Layer:
)
)
)
)
else:
else:
match_keys.add(key)
returnparam,state
returnparam,state
else:
else:
state_shape=(
state_shape=(
...
@@ -1640,11 +1647,13 @@ class Layer:
...
@@ -1640,11 +1647,13 @@ class Layer:
)
)
iflist(state_shape)!=list(param.shape):
iflist(state_shape)!=list(param.shape):
missing_keys.append(key)
raiseValueError(
raiseValueError(
"{} receives a shape {}, but the expected shape is {}.".format(
"{} receives a shape {}, but the expected shape is {}.".format(
key,list(state_shape),list(param.shape)
key,list(state_shape),list(param.shape)
)
)
)
)
match_keys.add(key)
returnparam,state
returnparam,state
matched_param_state=[]
matched_param_state=[]
...
@@ -1655,7 +1664,9 @@ class Layer:
...
@@ -1655,7 +1664,9 @@ class Layer:
matched_param_state.append(match_res)
matched_param_state.append(match_res)
exceptValueErroraserr:
exceptValueErroraserr:
warnings.warn(("Skip loading for {}. ".format(key)+str(err)))
warnings.warn(("Skip loading for {}. ".format(key)+str(err)))
forkeyinstate_dict.keys():
ifkeynotinmatch_keys:
unexpected_keys.append(key)
if_non_static_mode():
if_non_static_mode():
forparam,stateinmatched_param_state:
forparam,stateinmatched_param_state:
param.set_value(state)
param.set_value(state)
...
@@ -1693,6 +1704,8 @@ class Layer:
...
@@ -1693,6 +1704,8 @@ class Layer:
"This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'."
"This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'."
)
)
returnmissing_keys,unexpected_keys
defto(self,device=None,dtype=None,blocking=None):
defto(self,device=None,dtype=None,blocking=None):
'''
'''
Cast the parameters and buffers of Layer by the give device, dtype and blocking.
Cast the parameters and buffers of Layer by the give device, dtype and blocking.