Unverified · Commit 934d9986, authored by tangwei12, committed by GitHub

add selected rows supported in framework (#21808)

* add selected rows supported in framework
Parent 855ed5fb
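In short, this commit threads a `type` argument through `Parameter`, `LayerHelperBase.create_parameter`, and `Optimizer._add_accumulator`, so a parameter (and the accumulators an optimizer creates for it) can be a SELECTED_ROWS variable instead of the previously hard-wired LOD_TENSOR. A minimal sketch of what this enables (hypothetical names and shapes, assuming the fluid 1.x API of this commit's era):

    import paddle.fluid as fluid
    from paddle.fluid import core

    main = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(main, startup):
        # A parameter whose underlying storage type is SELECTED_ROWS.
        w = main.global_block().create_parameter(
            name="w",
            shape=[100, 8],
            dtype="float32",
            type=core.VarDesc.VarType.SELECTED_ROWS,
            initializer=fluid.initializer.Constant(value=0.0))
        print(w.type)  # VarType.SELECTED_ROWS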
@@ -568,7 +568,6 @@ def _varbase_creator(type=core.VarDesc.VarType.LOD_TENSOR,
                      dtype=None,
                      persistable=None,
                      **kwargs):
     if dtype is not None:
         if not isinstance(dtype, core.VarDesc.VarType):
             dtype = convert_np_dtype_to_dtype_(dtype)
@@ -1398,6 +1397,10 @@ class Variable(object):
         # TODO(minqiyang): Support lod_level in dygraph mode
         if in_dygraph_mode():
             raise Exception("Dygraph model DO NOT supprt lod")
+
+        if self.type == core.VarDesc.VarType.SELECTED_ROWS:
+            raise Exception("SelectedRows DO NOT supprt lod")
+
         return self.desc.lod_level()

     @property
...@@ -2445,7 +2448,7 @@ class Block(object): ...@@ -2445,7 +2448,7 @@ class Block(object):
" is inited by multiple init ops " + str( " is inited by multiple init ops " + str(
init_ops)) init_ops))
elif init_ops_len == 1: elif init_ops_len == 1:
#TODO already inited, do nothing, should log a warning # TODO already inited, do nothing, should log a warning
pass pass
else: else:
initializer(param, self) initializer(param, self)
@@ -4525,7 +4528,12 @@ class Parameter(Variable):
             be applied on this parameter.
     """

-    def __init__(self, block, shape, dtype, **kwargs):
+    def __init__(self,
+                 block,
+                 shape,
+                 dtype,
+                 type=core.VarDesc.VarType.LOD_TENSOR,
+                 **kwargs):
         if shape is None:
             raise ValueError("The shape of Parameter should not be None")
         if dtype is None:
@@ -4542,7 +4550,13 @@ class Parameter(Variable):
                 % list(shape))

         Variable.__init__(
-            self, block, persistable=True, shape=shape, dtype=dtype, **kwargs)
+            self,
+            block,
+            persistable=True,
+            shape=shape,
+            dtype=dtype,
+            type=type,
+            **kwargs)
         self.trainable = kwargs.get('trainable', True)
         self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
@@ -4660,7 +4674,7 @@ class ParamBase(core.VarBase):
         self.is_distributed = False

-        #self.block = default_main_program().global_block()
+        # self.block = default_main_program().global_block()

         _dygraph_tracer().trace_var(name, self)
...
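The guard added to `Variable.lod_level` above makes the dense-tensor-only nature of LoD explicit: a SELECTED_ROWS variable has no LoD, so the property now fails fast instead of consulting a meaningless desc field. A quick illustration, along the lines of the unit test added at the end of this diff:

    import paddle.fluid as fluid

    block = fluid.default_main_program().current_block()
    rows = block.create_var(
        name="rows",
        shape=[4, 4],
        dtype="float32",
        type=fluid.core.VarDesc.VarType.SELECTED_ROWS,
        persistable=True)

    try:
        rows.lod_level  # raises Exception("SelectedRows DO NOT supprt lod")
    except Exception as e:
        print(e)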
@@ -280,7 +280,8 @@ class LayerHelperBase(object):
                          dtype,
                          is_bias=False,
                          default_initializer=None,
-                         stop_gradient=False):
+                         stop_gradient=False,
+                         type=core.VarDesc.VarType.LOD_TENSOR):
         """Create parameters for this layers.

            Args:
@@ -334,15 +335,17 @@ class LayerHelperBase(object):
             return self.main_program.global_block().create_parameter(
                 dtype=dtype,
                 shape=shape,
+                type=type,
                 stop_gradient=stop_gradient,
                 **attr._to_kwargs(with_initializer=True))
         else:
             self.startup_program.global_block().create_parameter(
                 dtype=dtype,
                 shape=shape,
+                type=type,
                 **attr._to_kwargs(with_initializer=True))
             return self.main_program.global_block().create_parameter(
-                dtype=dtype, shape=shape, **attr._to_kwargs())
+                dtype=dtype, shape=shape, type=type, **attr._to_kwargs())

     def create_variable_for_type_inference(self, dtype, stop_gradient=False):
         """Create a temporary variable that should be type inferred layer.
...
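With the new `type` parameter on `create_parameter`, a layer implementation can request a selected-rows weight through its `LayerHelper` instead of touching block internals. A hedged sketch (the layer name, ParamAttr name, and shape are illustrative, not from this PR):

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.layer_helper import LayerHelper

    helper = LayerHelper("sparse_fc")
    w = helper.create_parameter(
        attr=fluid.ParamAttr(name="sparse_fc.w_0"),
        shape=[1000, 64],
        dtype="float32",
        type=core.VarDesc.VarType.SELECTED_ROWS)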
@@ -324,7 +324,8 @@ class Optimizer(object):
                          param,
                          dtype=None,
                          fill_value=0.0,
-                         shape=None):
+                         shape=None,
+                         type=None):
         """Utility function to add an accumulator for a parameter

         Args:
...@@ -354,7 +355,7 @@ class Optimizer(object): ...@@ -354,7 +355,7 @@ class Optimizer(object):
name=var_name, name=var_name,
persistable=True, persistable=True,
dtype=dtype or param.dtype, dtype=dtype or param.dtype,
type=param.type, type=param.type if type is None else type,
shape=shape, shape=shape,
belong_to_optimizer=True) belong_to_optimizer=True)
self.helper.set_variable_initializer( self.helper.set_variable_initializer(
@@ -1635,13 +1636,15 @@ class AdamOptimizer(Optimizer):
                 param=p,
                 fill_value=0.9 if isinstance(self._beta1, Variable) \
                     else self._beta1,
-                shape=[1])
+                shape=[1],
+                type=core.VarDesc.VarType.LOD_TENSOR)
             self._add_accumulator(
                 name=self._beta2_pow_acc_str,
                 param=p,
                 fill_value=0.999 if isinstance(self._beta2, Variable) \
                     else self._beta2,
-                shape=[1])
+                shape=[1],
+                type=core.VarDesc.VarType.LOD_TENSOR)

     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
...
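The `_add_accumulator` change is what makes sparse parameters workable end to end: an accumulator now defaults to the parameter's own type, so the moments of a SELECTED_ROWS parameter are themselves SELECTED_ROWS, while `AdamOptimizer` explicitly pins its shape-[1] beta-pow accumulators to LOD_TENSOR, since those are dense scalars regardless of the parameter. A toy restatement of the rule (hypothetical helper, not code from this PR):

    from paddle.fluid import core

    def accumulator_type(param_type, requested=None):
        # Mirrors `param.type if type is None else type` in
        # Optimizer._add_accumulator: follow the parameter's type unless
        # the caller pins a concrete one.
        return param_type if requested is None else requested

    SR = core.VarDesc.VarType.SELECTED_ROWS
    LT = core.VarDesc.VarType.LOD_TENSOR
    assert accumulator_type(SR) == SR      # moments follow the parameter
    assert accumulator_type(SR, LT) == LT  # beta-pow stays dense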
@@ -206,6 +206,21 @@ class TestVariable(unittest.TestCase):
         self.assertIsNone(var.dtype)
         self.assertIsNone(var.type)

+    def test_create_selected_rows(self):
+        b = default_main_program().current_block()
+        var = b.create_var(
+            name="var",
+            shape=[1, 1],
+            dtype="float32",
+            type=fluid.core.VarDesc.VarType.SELECTED_ROWS,
+            persistable=True)
+
+        def _test():
+            var.lod_level()
+
+        self.assertRaises(Exception, _test)
+

 if __name__ == '__main__':
     unittest.main()