提交 63d0d2e0 编写于 作者: W wuzewu

fix type error bug when deserialize parameter attribute

上级 c0604646
......@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
from paddle_hub import module_desc_pb2
from paddle_hub.utils import from_pyobj_to_flexible_data, from_flexible_data_to_pyobj
from paddle_hub.logger import logger
import paddle
import paddle.fluid as fluid
......@@ -81,35 +82,37 @@ def from_flexible_data_to_param(flexible_data):
flexible_data.map.data['optimize_attr'])
if flexible_data.map.data['regularizer'].type != module_desc_pb2.NONE:
regularizer_type = flexible_data.map.data['regularizer'].name
regularization_coeff = flexible_data.map.data[
'regularizer'].object.data['_regularization_coeff '].f
regularization_coeff = from_flexible_data_to_pyobj(
flexible_data.map.data['regularizer'].object.
data['_regularization_coeff'])
param['regularizer'] = eval(
"fluid.regularizer.%s(regularization_coeff = %f)" %
(regularizer_type, regularization_coeff))
if flexible_data.map.data['regularizer'].type != module_desc_pb2.NONE:
if flexible_data.map.data['gradient_clip_attr'].type != module_desc_pb2.NONE:
clip_type = flexible_data.map.data['gradient_clip_attr'].name
if clip_type == "ErrorClipByValue" or clip_type == "GradientClipByValue":
max = flexible_data.map.data[
'regularizer'].name, flexible_data.map.data[
'gradient_clip_attr'].object.data['max'].f
min = flexible_data.map.data[
'regularizer'].name, flexible_data.map.data[
'gradient_clip_attr'].object.data['min'].f
max = from_flexible_data_to_pyobj(
flexible_data.map.data['gradient_clip_attr'].object.data['max'])
min = from_flexible_data_to_pyobj(
flexible_data.map.data['gradient_clip_attr'].object.data['min'])
param['gradient_clip_attr'] = eval(
"fluid.clip.%s(max = %f, min = %f)" % (clip_type, max, min))
if clip_type == "GradientClipByNorm":
clip_norm = flexible_data.map.data[
'gradient_clip_attr'].object.data['clip_norm'].f
clip_norm = from_flexible_data_to_pyobj(
flexible_data.map.data['gradient_clip_attr'].object.
data['clip_norm'])
param['gradient_clip_attr'] = eval(
"fluid.clip.%s(clip_norm = %f)" % (clip_type, clip_norm))
if clip_type == "GradientClipByGlobalNorm":
clip_norm = flexible_data.map.data[
'gradient_clip_attr'].object.data['clip_norm'].f
group_name = flexible_data.map.data[
'gradient_clip_attr'].object.data['group_name'].f
clip_norm = from_flexible_data_to_pyobj(
flexible_data.map.data['gradient_clip_attr'].object.
data['clip_norm'])
group_name = from_flexible_data_to_pyobj(
flexible_data.map.data['gradient_clip_attr'].object.
data['group_name'])
param['gradient_clip_attr'] = eval(
"fluid.clip.%s(clip_norm = %f, group_name = %f)" %
"fluid.clip.%s(clip_norm = %f, group_name = \"%s\")" %
(clip_type, clip_norm, group_name))
return param
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册