未验证 提交 934d9986 编写于 作者: T tangwei12 提交者: GitHub

add selected rows supported in framework (#21808)

* add selected rows supported in framework
上级 855ed5fb
......@@ -568,7 +568,6 @@ def _varbase_creator(type=core.VarDesc.VarType.LOD_TENSOR,
dtype=None,
persistable=None,
**kwargs):
if dtype is not None:
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
......@@ -1398,6 +1397,10 @@ class Variable(object):
# TODO(minqiyang): Support lod_level in dygraph mode
if in_dygraph_mode():
raise Exception("Dygraph model DO NOT supprt lod")
if self.type == core.VarDesc.VarType.SELECTED_ROWS:
raise Exception("SelectedRows DO NOT supprt lod")
return self.desc.lod_level()
@property
......@@ -2445,7 +2448,7 @@ class Block(object):
" is inited by multiple init ops " + str(
init_ops))
elif init_ops_len == 1:
#TODO already inited, do nothing, should log a warning
# TODO already inited, do nothing, should log a warning
pass
else:
initializer(param, self)
......@@ -4525,7 +4528,12 @@ class Parameter(Variable):
be applied on this parameter.
"""
def __init__(self, block, shape, dtype, **kwargs):
def __init__(self,
block,
shape,
dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
**kwargs):
if shape is None:
raise ValueError("The shape of Parameter should not be None")
if dtype is None:
......@@ -4542,7 +4550,13 @@ class Parameter(Variable):
% list(shape))
Variable.__init__(
self, block, persistable=True, shape=shape, dtype=dtype, **kwargs)
self,
block,
persistable=True,
shape=shape,
dtype=dtype,
type=type,
**kwargs)
self.trainable = kwargs.get('trainable', True)
self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
......@@ -4660,7 +4674,7 @@ class ParamBase(core.VarBase):
self.is_distributed = False
#self.block = default_main_program().global_block()
# self.block = default_main_program().global_block()
_dygraph_tracer().trace_var(name, self)
......
......@@ -280,7 +280,8 @@ class LayerHelperBase(object):
dtype,
is_bias=False,
default_initializer=None,
stop_gradient=False):
stop_gradient=False,
type=core.VarDesc.VarType.LOD_TENSOR):
"""Create parameters for this layers.
Args:
......@@ -334,15 +335,17 @@ class LayerHelperBase(object):
return self.main_program.global_block().create_parameter(
dtype=dtype,
shape=shape,
type=type,
stop_gradient=stop_gradient,
**attr._to_kwargs(with_initializer=True))
else:
self.startup_program.global_block().create_parameter(
dtype=dtype,
shape=shape,
type=type,
**attr._to_kwargs(with_initializer=True))
return self.main_program.global_block().create_parameter(
dtype=dtype, shape=shape, **attr._to_kwargs())
dtype=dtype, shape=shape, type=type, **attr._to_kwargs())
def create_variable_for_type_inference(self, dtype, stop_gradient=False):
"""Create a temporary variable that should be type inferred layer.
......
......@@ -324,7 +324,8 @@ class Optimizer(object):
param,
dtype=None,
fill_value=0.0,
shape=None):
shape=None,
type=None):
"""Utility function to add an accumulator for a parameter
Args:
......@@ -354,7 +355,7 @@ class Optimizer(object):
name=var_name,
persistable=True,
dtype=dtype or param.dtype,
type=param.type,
type=param.type if type is None else type,
shape=shape,
belong_to_optimizer=True)
self.helper.set_variable_initializer(
......@@ -1635,13 +1636,15 @@ class AdamOptimizer(Optimizer):
param=p,
fill_value=0.9 if isinstance(self._beta1, Variable) \
else self._beta1,
shape=[1])
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR)
self._add_accumulator(
name=self._beta2_pow_acc_str,
param=p,
fill_value=0.999 if isinstance(self._beta2, Variable) \
else self._beta2,
shape=[1])
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR)
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
......
......@@ -206,6 +206,21 @@ class TestVariable(unittest.TestCase):
self.assertIsNone(var.dtype)
self.assertIsNone(var.type)
def test_create_selected_rows(self):
b = default_main_program().current_block()
var = b.create_var(
name="var",
shape=[1, 1],
dtype="float32",
type=fluid.core.VarDesc.VarType.SELECTED_ROWS,
persistable=True)
def _test():
var.lod_level()
self.assertRaises(Exception, _test)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册