Created by: chenwhql
PR types
Function optimization
PR changes
APIs
Describe
Use structured name as key in loaded state_dict, instead use parameter.name
original example:
import paddle
import paddle.nn as nn
IMAGE_SIZE = 784
CLASS_NUM = 10
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
@paddle.jit.to_static
def forward(self, x):
return self._linear(x)
# enable dygraph mode
paddle.disable_static()
# create network
layer = LinearNet()
# load
model_path = "linear.example.model"
state_dict, _ = paddle.load(model_path)
# inference
layer.set_state_dict(state_dict, use_structured_name=False)
layer.eval()
x = paddle.randn([1, IMAGE_SIZE], 'float32')
pred = layer(x)
new example change:
# inference
layer.set_state_dict(state_dict)
layer.eval()
x = paddle.randn([1, IMAGE_SIZE], 'float32')
pred = layer(x)
Pay attention to the use of Layer.set_state_dict
, no need to configure use_structured_name=False