未验证 提交 d4bdbf8c 编写于 作者: L Leo Chen 提交者: GitHub

Polish nn code, test=develop (#22237)

* refine code, test=develop

* reuse contain_var, test=develop
上级 efcdeb51
...@@ -455,14 +455,14 @@ def name_scope(prefix=None): ...@@ -455,14 +455,14 @@ def name_scope(prefix=None):
""" """
# TODO(panyx0718): Only [0-9a-z]. # TODO(panyx0718): Only [0-9a-z].
# in dygraph we don't need namescope since it will cause mem leak # in dygraph we don't need namescope since it will cause mem leak
if not in_dygraph_mode(): if in_dygraph_mode():
yield
else:
assert prefix, "namescope prefix cannot be empty." assert prefix, "namescope prefix cannot be empty."
global _name_scope global _name_scope
_name_scope = _name_scope.child(prefix) _name_scope = _name_scope.child(prefix)
yield yield
_name_scope = _name_scope.parent() _name_scope = _name_scope.parent()
else:
yield
def _full_name_scope(): def _full_name_scope():
...@@ -715,10 +715,9 @@ def _getitem_impl_(var, item): ...@@ -715,10 +715,9 @@ def _getitem_impl_(var, item):
if (use_strided_slice == True): if (use_strided_slice == True):
attrs['strides'] = [] attrs['strides'] = []
infer_flags = list(1 for i in range(len(slice_axis))) infer_flags = list(1 for i in range(len(slice_axis)))
# starts # starts
if not contain_var(slice_start): if contain_var(slice_start):
attrs['starts'] = slice_start
else:
inputs['StartsTensorList'] = get_new_list_tensor(slice_start) inputs['StartsTensorList'] = get_new_list_tensor(slice_start)
for i, dim in enumerate(slice_start): for i, dim in enumerate(slice_start):
if isinstance(dim, Variable): if isinstance(dim, Variable):
...@@ -726,10 +725,11 @@ def _getitem_impl_(var, item): ...@@ -726,10 +725,11 @@ def _getitem_impl_(var, item):
infer_flags[i] = -1 infer_flags[i] = -1
else: else:
attrs['starts'].append(dim) attrs['starts'].append(dim)
# ends
if not contain_var(slice_end):
attrs['ends'] = slice_end
else: else:
attrs['starts'] = slice_start
# ends
if contain_var(slice_end):
inputs['EndsTensorList'] = get_new_list_tensor(slice_end) inputs['EndsTensorList'] = get_new_list_tensor(slice_end)
for i, dim in enumerate(slice_end): for i, dim in enumerate(slice_end):
if isinstance(dim, Variable): if isinstance(dim, Variable):
...@@ -737,11 +737,12 @@ def _getitem_impl_(var, item): ...@@ -737,11 +737,12 @@ def _getitem_impl_(var, item):
infer_flags[i] = -1 infer_flags[i] = -1
else: else:
attrs['ends'].append(dim) attrs['ends'].append(dim)
else:
attrs['ends'] = slice_end
# strides # strides
if use_strided_slice == True: if use_strided_slice == True:
if not contain_var(slice_step): if contain_var(slice_step):
attrs['strides'] = slice_step
else:
inputs['StridesTensorList'] = get_new_list_tensor(slice_step) inputs['StridesTensorList'] = get_new_list_tensor(slice_step)
for i, dim in enumerate(slice_step): for i, dim in enumerate(slice_step):
if isinstance(dim, Variable): if isinstance(dim, Variable):
...@@ -749,6 +750,8 @@ def _getitem_impl_(var, item): ...@@ -749,6 +750,8 @@ def _getitem_impl_(var, item):
infer_flags[i] = -1 infer_flags[i] = -1
else: else:
attrs['strides'].append(dim) attrs['strides'].append(dim)
else:
attrs['strides'] = slice_step
# infer_flags # infer_flags
attrs['infer_flags'] = infer_flags attrs['infer_flags'] = infer_flags
...@@ -2344,12 +2347,12 @@ class Block(object): ...@@ -2344,12 +2347,12 @@ class Block(object):
if isinstance(item[1], Parameter)) if isinstance(item[1], Parameter))
def create_var(self, *args, **kwargs): def create_var(self, *args, **kwargs):
if not in_dygraph_mode(): if in_dygraph_mode():
var = _varbase_creator(*args, **kwargs)
else:
var = Variable(block=self, *args, **kwargs) var = Variable(block=self, *args, **kwargs)
if 'initializer' in kwargs: if 'initializer' in kwargs:
kwargs['initializer'](var, self) kwargs['initializer'](var, self)
else:
var = _varbase_creator(*args, **kwargs)
return var return var
def has_var(self, name): def has_var(self, name):
...@@ -2396,9 +2399,8 @@ class Block(object): ...@@ -2396,9 +2399,8 @@ class Block(object):
# NOTE: v is destroyed by C++ after calling _rename_var. # NOTE: v is destroyed by C++ after calling _rename_var.
d = self.desc.find_var(cpt.to_bytes(new_name)) d = self.desc.find_var(cpt.to_bytes(new_name))
if var_type == "Parameter": if var_type == "Parameter":
if not in_dygraph_mode(): if in_dygraph_mode():
var = Parameter( var = ParamBase(
self,
d.shape(), d.shape(),
d.dtype(), d.dtype(),
type=orig_var_type, type=orig_var_type,
...@@ -2410,7 +2412,8 @@ class Block(object): ...@@ -2410,7 +2412,8 @@ class Block(object):
gradient_clip_attr=gradient_clip_attr, gradient_clip_attr=gradient_clip_attr,
error_clip=error_clip) error_clip=error_clip)
else: else:
var = ParamBase( var = Parameter(
self,
d.shape(), d.shape(),
d.dtype(), d.dtype(),
type=orig_var_type, type=orig_var_type,
...@@ -2444,10 +2447,10 @@ class Block(object): ...@@ -2444,10 +2447,10 @@ class Block(object):
def create_parameter(self, *args, **kwargs): def create_parameter(self, *args, **kwargs):
global_block = self.program.global_block() global_block = self.program.global_block()
param = None param = None
if not in_dygraph_mode(): if in_dygraph_mode():
param = Parameter(global_block, *args, **kwargs)
else:
param = ParamBase(*args, **kwargs) param = ParamBase(*args, **kwargs)
else:
param = Parameter(global_block, *args, **kwargs)
if 'initializer' in kwargs: if 'initializer' in kwargs:
def _is_inited_by(block, var): def _is_inited_by(block, var):
...@@ -2687,9 +2690,8 @@ class Block(object): ...@@ -2687,9 +2690,8 @@ class Block(object):
"same topology") "same topology")
assert isinstance(v, Variable) assert isinstance(v, Variable)
new_p = None new_p = None
if not in_dygraph_mode(): if in_dygraph_mode():
new_p = Parameter( new_p = ParamBase(
block=self,
shape=v.shape, shape=v.shape,
dtype=v.dtype, dtype=v.dtype,
type=v.type, type=v.type,
...@@ -2702,7 +2704,8 @@ class Block(object): ...@@ -2702,7 +2704,8 @@ class Block(object):
error_clip=p.error_clip, error_clip=p.error_clip,
name=v.name) name=v.name)
else: else:
new_p = ParamBase( new_p = Parameter(
block=self,
shape=v.shape, shape=v.shape,
dtype=v.dtype, dtype=v.dtype,
type=v.type, type=v.type,
......
...@@ -4346,24 +4346,23 @@ def split(input, num_or_sections, dim=-1, name=None): ...@@ -4346,24 +4346,23 @@ def split(input, num_or_sections, dim=-1, name=None):
if isinstance(num_or_sections, int): if isinstance(num_or_sections, int):
num = num_or_sections num = num_or_sections
attrs['num'] = num_or_sections attrs['num'] = num_or_sections
res = core.ops.split(inputs, attrs, {}, {'Out': num}) elif isinstance(num_or_sections, (list, tuple)):
return res['Out']
elif isinstance(num_or_sections, list):
num = len(num_or_sections) num = len(num_or_sections)
attrs['sections'] = list( if utils._contain_var(num_or_sections):
map(lambda ele: -1 if isinstance(ele, Variable) else ele,
num_or_sections))
contain_var = not all(not isinstance(ele, Variable)
for ele in num_or_sections)
if contain_var:
raise TypeError( raise TypeError(
"The type of 'num_or_sections' in split must be int or list[int] in Dygraph mode, but " "The type of 'num_or_sections' in split must be int or list[int] or tuple[int] in Dygraph mode, but "
"received %s." % ('list[Variable]')) "received %s, which contains Variable." %
(type(num_or_sections)))
else:
attrs['sections'] = list(num_or_sections)
else: else:
raise TypeError( raise TypeError(
"The type of 'num_or_sections' in split must be int or list in Dygraph mode, but " "The type of 'num_or_sections' in split must be int or list in Dygraph mode, but "
"received %s." % (type(num_or_sections))) "received %s." % (type(num_or_sections)))
res = core.ops.split(inputs, attrs, {}, {'Out': num})
return res['Out']
if not isinstance(num_or_sections, (int, list, tuple)): if not isinstance(num_or_sections, (int, list, tuple)):
raise TypeError( raise TypeError(
"The type of 'num_or_sections' in split must be int, list or " "The type of 'num_or_sections' in split must be int, list or "
...@@ -4422,9 +4421,7 @@ def split(input, num_or_sections, dim=-1, name=None): ...@@ -4422,9 +4421,7 @@ def split(input, num_or_sections, dim=-1, name=None):
attrs['sections'] = list( attrs['sections'] = list(
map(lambda ele: -1 if isinstance(ele, Variable) else ele, map(lambda ele: -1 if isinstance(ele, Variable) else ele,
num_or_sections)) num_or_sections))
contain_var = not all(not isinstance(ele, Variable) if utils._contain_var(num_or_sections):
for ele in num_or_sections)
if contain_var:
inputs['SectionsTensorList'] = _get_SectionsTensorList( inputs['SectionsTensorList'] = _get_SectionsTensorList(
num_or_sections) num_or_sections)
...@@ -5572,16 +5569,14 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): ...@@ -5572,16 +5569,14 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
# the shape of reshaped_3 is [6,8]. # the shape of reshaped_3 is [6,8].
""" """
if in_dygraph_mode(): if in_dygraph_mode():
#TODO(zhiqiu): open inplace if we can. #TODO(zhiqiu): enable inplace in dygraph mode.
if inplace: if inplace:
warnings.warn( warnings.warn(
"Inplace on reshape is not allowed and will be discarded in dygraph mode currently." "Inplace on reshape is not allowed and will be discarded in dygraph mode currently."
) )
attrs = {} attrs = {}
if isinstance(shape, (list, tuple)): if isinstance(shape, (list, tuple)):
contain_var = not all(not isinstance(ele, Variable) if utils._contain_var(shape):
for ele in shape)
if contain_var:
raise TypeError( raise TypeError(
"The type of 'shape' in reshape must be list[int] or tuple(int) in Dygraph mode, but " "The type of 'shape' in reshape must be list[int] or tuple(int) in Dygraph mode, but "
"received %s, which contains Variable." % type(shape)) "received %s, which contains Variable." % type(shape))
...@@ -5604,12 +5599,6 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): ...@@ -5604,12 +5599,6 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
helper = LayerHelper("reshape2", **locals()) helper = LayerHelper("reshape2", **locals())
def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
def get_new_shape_tensor(list_shape): def get_new_shape_tensor(list_shape):
new_shape_tensor = [] new_shape_tensor = []
for dim in list_shape: for dim in list_shape:
...@@ -5659,7 +5648,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): ...@@ -5659,7 +5648,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
assert len(shape) > 0, ("The size of 'shape' in reshape can't be zero, " assert len(shape) > 0, ("The size of 'shape' in reshape can't be zero, "
"but received %s." % len(shape)) "but received %s." % len(shape))
attrs["shape"] = get_attr_shape(shape) attrs["shape"] = get_attr_shape(shape)
if contain_var(shape): if utils._contain_var(shape):
inputs['ShapeTensor'] = get_new_shape_tensor(shape) inputs['ShapeTensor'] = get_new_shape_tensor(shape)
elif isinstance(actual_shape, Variable): elif isinstance(actual_shape, Variable):
actual_shape.stop_gradient = True actual_shape.stop_gradient = True
...@@ -5804,8 +5793,7 @@ def unsqueeze(input, axes, name=None): ...@@ -5804,8 +5793,7 @@ def unsqueeze(input, axes, name=None):
axes.stop_gradient = True axes.stop_gradient = True
inputs["AxesTensor"] = axes inputs["AxesTensor"] = axes
elif isinstance(axes, (list, tuple)): elif isinstance(axes, (list, tuple)):
contain_var = not all(not isinstance(ele, Variable) for ele in axes) if utils._contain_var(axes):
if contain_var:
inputs["AxesTensorList"] = _to_Variable_list(axes) inputs["AxesTensorList"] = _to_Variable_list(axes)
else: else:
attrs["axes"] = axes attrs["axes"] = axes
...@@ -8256,12 +8244,6 @@ def crop_tensor(x, shape=None, offsets=None, name=None): ...@@ -8256,12 +8244,6 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
ipts = {'X': x} ipts = {'X': x}
attrs = {} attrs = {}
def _contain_var(input_list):
for ele in input_list:
if isinstance(ele, Variable):
return True
return False
def _attr_shape_check(shape_val): def _attr_shape_check(shape_val):
if not isinstance(shape_val, int): if not isinstance(shape_val, int):
raise TypeError( raise TypeError(
...@@ -8290,7 +8272,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None): ...@@ -8290,7 +8272,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
offsets.stop_gradient = True offsets.stop_gradient = True
ipts['Offsets'] = offsets ipts['Offsets'] = offsets
attrs['offsets'] = [-1] * len(x.shape) attrs['offsets'] = [-1] * len(x.shape)
elif _contain_var(offsets): elif utils._contain_var(offsets):
new_offsets_tensor = [] new_offsets_tensor = []
offsets_attr = [] offsets_attr = []
for dim in offsets: for dim in offsets:
...@@ -8314,7 +8296,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None): ...@@ -8314,7 +8296,7 @@ def crop_tensor(x, shape=None, offsets=None, name=None):
if isinstance(shape, Variable): if isinstance(shape, Variable):
shape.stop_gradient = True shape.stop_gradient = True
ipts['Shape'] = shape ipts['Shape'] = shape
elif _contain_var(shape): elif utils._contain_var(shape):
new_shape_tensor = [] new_shape_tensor = []
shape_attr = [] shape_attr = []
for dim_size in shape: for dim_size in shape:
...@@ -9344,20 +9326,12 @@ def expand(x, expand_times, name=None): ...@@ -9344,20 +9326,12 @@ def expand(x, expand_times, name=None):
expanded_2 = fluid.layers.expand(data_2, expand_times=expand_times) expanded_2 = fluid.layers.expand(data_2, expand_times=expand_times)
# the shape of expanded_2 is [48, 56]. # the shape of expanded_2 is [48, 56].
""" """
def contain_var(expand_times):
for ele in expand_times:
if isinstance(ele, Variable):
return True
return False
inputs = {"X": [x]} inputs = {"X": [x]}
attrs = {} attrs = {}
if in_dygraph_mode(): if in_dygraph_mode():
if isinstance(expand_times, (list, tuple)): if isinstance(expand_times, (list, tuple)):
contain_var = contain_var(expand_times) if utils._contain_var(expand_times):
if contain_var:
raise TypeError( raise TypeError(
"The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but " "The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but "
"received %s, which contains Variable." % type(shape)) "received %s, which contains Variable." % type(shape))
...@@ -9404,16 +9378,12 @@ def expand(x, expand_times, name=None): ...@@ -9404,16 +9378,12 @@ def expand(x, expand_times, name=None):
new_expand_times_tensor.append(temp_out) new_expand_times_tensor.append(temp_out)
return new_expand_times_tensor return new_expand_times_tensor
if in_dygraph_mode():
inputs = {'X': x}
attrs = {'expand_times': expand_times}
else:
if isinstance(expand_times, Variable): if isinstance(expand_times, Variable):
expand_times.stop_gradient = True expand_times.stop_gradient = True
inputs['ExpandTimes'] = expand_times inputs['ExpandTimes'] = expand_times
elif isinstance(expand_times, (list, tuple)): elif isinstance(expand_times, (list, tuple)):
attrs['expand_times'] = get_attr_expand_times(expand_times) attrs['expand_times'] = get_attr_expand_times(expand_times)
if contain_var(expand_times): if utils._contain_var(expand_times):
inputs['expand_times_tensor'] = get_new_expand_times_tensor( inputs['expand_times_tensor'] = get_new_expand_times_tensor(
expand_times) expand_times)
...@@ -9912,19 +9882,12 @@ def slice(input, axes, starts, ends): ...@@ -9912,19 +9882,12 @@ def slice(input, axes, starts, ends):
sliced_2 = fluid.layers.slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends) sliced_2 = fluid.layers.slice(input, axes=axes, starts=[minus_3, 0, 2], ends=ends)
# sliced_2 is input[0:3, 0:2, 2:4]. # sliced_2 is input[0:3, 0:2, 2:4].
""" """
def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
if in_dygraph_mode(): if in_dygraph_mode():
infer_flags = list(1 for i in range(len(axes))) infer_flags = list(1 for i in range(len(axes)))
inputs = {'Input': [input]} inputs = {'Input': [input]}
if isinstance(starts, (list, tuple)): if isinstance(starts, (list, tuple)):
if contain_var(starts): if utils._contain_var(starts):
raise TypeError( raise TypeError(
"The type of 'starts' in slice must be list[int] or tuple(int) in Dygraph mode, but " "The type of 'starts' in slice must be list[int] or tuple(int) in Dygraph mode, but "
"received %s, which contains Variable." % type(shape)) "received %s, which contains Variable." % type(shape))
...@@ -9934,7 +9897,7 @@ def slice(input, axes, starts, ends): ...@@ -9934,7 +9897,7 @@ def slice(input, axes, starts, ends):
"received %s." % type(shape)) "received %s." % type(shape))
if isinstance(ends, (list, tuple)): if isinstance(ends, (list, tuple)):
if contain_var(ends): if utils._contain_var(ends):
raise TypeError( raise TypeError(
"The type of 'ends' in slice must be list[int] or tuple(int) in Dygraph mode, but " "The type of 'ends' in slice must be list[int] or tuple(int) in Dygraph mode, but "
"received %s, which contains Variable." % type(shape)) "received %s, which contains Variable." % type(shape))
...@@ -9985,9 +9948,7 @@ def slice(input, axes, starts, ends): ...@@ -9985,9 +9948,7 @@ def slice(input, axes, starts, ends):
infer_flags = list(-1 for i in range(len(axes))) infer_flags = list(-1 for i in range(len(axes)))
elif isinstance(starts, (list, tuple)): elif isinstance(starts, (list, tuple)):
attrs['starts'] = [] attrs['starts'] = []
if not contain_var(starts): if utils._contain_var(starts):
attrs['starts'] = starts
else:
inputs['StartsTensorList'] = get_new_list_tensor(starts) inputs['StartsTensorList'] = get_new_list_tensor(starts)
for i, dim in enumerate(starts): for i, dim in enumerate(starts):
if isinstance(dim, Variable): if isinstance(dim, Variable):
...@@ -9995,6 +9956,8 @@ def slice(input, axes, starts, ends): ...@@ -9995,6 +9956,8 @@ def slice(input, axes, starts, ends):
infer_flags[i] = -1 infer_flags[i] = -1
else: else:
attrs['starts'].append(dim) attrs['starts'].append(dim)
else:
attrs['starts'] = starts
# ends # ends
if isinstance(ends, Variable): if isinstance(ends, Variable):
...@@ -10003,9 +9966,7 @@ def slice(input, axes, starts, ends): ...@@ -10003,9 +9966,7 @@ def slice(input, axes, starts, ends):
infer_flags = list(-1 for i in range(len(axes))) infer_flags = list(-1 for i in range(len(axes)))
elif isinstance(ends, (list, tuple)): elif isinstance(ends, (list, tuple)):
attrs['ends'] = [] attrs['ends'] = []
if not contain_var(ends): if utils._contain_var(ends):
attrs['ends'] = ends
else:
inputs['EndsTensorList'] = get_new_list_tensor(ends) inputs['EndsTensorList'] = get_new_list_tensor(ends)
for i, dim in enumerate(ends): for i, dim in enumerate(ends):
if isinstance(dim, Variable): if isinstance(dim, Variable):
...@@ -10013,6 +9974,9 @@ def slice(input, axes, starts, ends): ...@@ -10013,6 +9974,9 @@ def slice(input, axes, starts, ends):
infer_flags[i] = -1 infer_flags[i] = -1
else: else:
attrs['ends'].append(dim) attrs['ends'].append(dim)
else:
attrs['ends'] = ends
# infer_flags # infer_flags
attrs['infer_flags'] = infer_flags attrs['infer_flags'] = infer_flags
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
...@@ -10130,12 +10094,6 @@ def strided_slice(input, axes, starts, ends, strides): ...@@ -10130,12 +10094,6 @@ def strided_slice(input, axes, starts, ends, strides):
helper = LayerHelper('strided_slice', **locals()) helper = LayerHelper('strided_slice', **locals())
def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
def get_new_list_tensor(old_list): def get_new_list_tensor(old_list):
new_list_tensor = [] new_list_tensor = []
for dim in old_list: for dim in old_list:
...@@ -10169,9 +10127,7 @@ def strided_slice(input, axes, starts, ends, strides): ...@@ -10169,9 +10127,7 @@ def strided_slice(input, axes, starts, ends, strides):
inputs['StartsTensor'] = starts inputs['StartsTensor'] = starts
elif isinstance(starts, (list, tuple)): elif isinstance(starts, (list, tuple)):
attrs['starts'] = [] attrs['starts'] = []
if not contain_var(starts): if utils._contain_var(starts):
attrs['starts'] = starts
else:
inputs['StartsTensorList'] = get_new_list_tensor(starts) inputs['StartsTensorList'] = get_new_list_tensor(starts)
for i, dim in enumerate(starts): for i, dim in enumerate(starts):
if isinstance(dim, Variable): if isinstance(dim, Variable):
...@@ -10179,6 +10135,8 @@ def strided_slice(input, axes, starts, ends, strides): ...@@ -10179,6 +10135,8 @@ def strided_slice(input, axes, starts, ends, strides):
infer_flags[i] = -1 infer_flags[i] = -1
else: else:
attrs['starts'].append(dim) attrs['starts'].append(dim)
else:
attrs['starts'] = starts
# ends # ends
if isinstance(ends, Variable): if isinstance(ends, Variable):
...@@ -10186,9 +10144,7 @@ def strided_slice(input, axes, starts, ends, strides): ...@@ -10186,9 +10144,7 @@ def strided_slice(input, axes, starts, ends, strides):
inputs['EndsTensor'] = ends inputs['EndsTensor'] = ends
elif isinstance(ends, (list, tuple)): elif isinstance(ends, (list, tuple)):
attrs['ends'] = [] attrs['ends'] = []
if not contain_var(ends): if utils._contain_var(ends):
attrs['ends'] = ends
else:
inputs['EndsTensorList'] = get_new_list_tensor(ends) inputs['EndsTensorList'] = get_new_list_tensor(ends)
for i, dim in enumerate(ends): for i, dim in enumerate(ends):
if isinstance(dim, Variable): if isinstance(dim, Variable):
...@@ -10196,15 +10152,16 @@ def strided_slice(input, axes, starts, ends, strides): ...@@ -10196,15 +10152,16 @@ def strided_slice(input, axes, starts, ends, strides):
infer_flags[i] = -1 infer_flags[i] = -1
else: else:
attrs['ends'].append(dim) attrs['ends'].append(dim)
else:
attrs['ends'] = ends
# strides # strides
if isinstance(strides, Variable): if isinstance(strides, Variable):
strides.stop_gradient = True strides.stop_gradient = True
inputs['StridesTensor'] = strides inputs['StridesTensor'] = strides
elif isinstance(strides, (list, tuple)): elif isinstance(strides, (list, tuple)):
attrs['strides'] = [] attrs['strides'] = []
if not contain_var(strides): if utils._contain_var(strides):
attrs['strides'] = strides
else:
inputs['StridesTensorList'] = get_new_list_tensor(strides) inputs['StridesTensorList'] = get_new_list_tensor(strides)
for i, dim in enumerate(strides): for i, dim in enumerate(strides):
if isinstance(dim, Variable): if isinstance(dim, Variable):
...@@ -10212,6 +10169,8 @@ def strided_slice(input, axes, starts, ends, strides): ...@@ -10212,6 +10169,8 @@ def strided_slice(input, axes, starts, ends, strides):
infer_flags[i] = -1 infer_flags[i] = -1
else: else:
attrs['strides'].append(dim) attrs['strides'].append(dim)
else:
attrs['strides'] = strides
attrs['infer_flags'] = infer_flags attrs['infer_flags'] = infer_flags
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('input')) dtype=helper.input_dtype('input'))
...@@ -13894,12 +13853,6 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): ...@@ -13894,12 +13853,6 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
check_dtype(dtype, 'dtype', ['float32', 'float64'], 'uniform_random') check_dtype(dtype, 'dtype', ['float32', 'float64'], 'uniform_random')
def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
def get_new_shape_tensor(list_shape): def get_new_shape_tensor(list_shape):
new_shape_tensor = [] new_shape_tensor = []
for dim in list_shape: for dim in list_shape:
...@@ -13939,7 +13892,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): ...@@ -13939,7 +13892,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
assert len(shape) > 0, ( assert len(shape) > 0, (
"The size of argument(shape) can't be zero.") "The size of argument(shape) can't be zero.")
attrs["shape"] = get_attr_shape(shape) attrs["shape"] = get_attr_shape(shape)
if contain_var(shape): if utils._contain_var(shape):
inputs['ShapeTensorList'] = get_new_shape_tensor(shape) inputs['ShapeTensorList'] = get_new_shape_tensor(shape)
out = helper.create_variable_for_type_inference(dtype) out = helper.create_variable_for_type_inference(dtype)
......
...@@ -22,6 +22,7 @@ from ..initializer import Constant, force_init_on_cpu ...@@ -22,6 +22,7 @@ from ..initializer import Constant, force_init_on_cpu
from ..core import VarDesc from ..core import VarDesc
from .. import core from .. import core
from .layer_function_generator import templatedoc from .layer_function_generator import templatedoc
from . import utils
from ..data_feeder import check_type_and_dtype, check_type, check_dtype, convert_dtype from ..data_feeder import check_type_and_dtype, check_type, check_dtype, convert_dtype
import numpy import numpy
import warnings import warnings
...@@ -552,13 +553,6 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): ...@@ -552,13 +553,6 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
shape = fluid.layers.fill_constant([1,2], "int32", 2) # shape=[2,2] shape = fluid.layers.fill_constant([1,2], "int32", 2) # shape=[2,2]
data4 = fluid.layers.fill_constant(shape=shape, dtype='bool', value=True) # data4=[[True,True],[True,True]] data4 = fluid.layers.fill_constant(shape=shape, dtype='bool', value=True) # data4=[[True,True],[True,True]]
""" """
def _contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
attrs = { attrs = {
'value': float(value), 'value': float(value),
'force_cpu': force_cpu or force_init_on_cpu() 'force_cpu': force_cpu or force_init_on_cpu()
...@@ -571,8 +565,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): ...@@ -571,8 +565,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
if in_dygraph_mode(): if in_dygraph_mode():
if isinstance(shape, (list, tuple)): if isinstance(shape, (list, tuple)):
contain_var = _contain_var(shape) if utils._contain_var(shape):
if contain_var:
raise TypeError( raise TypeError(
"The type of 'shape' in fill_constant must be list[int] or tuple(int) in Dygraph mode, but " "The type of 'shape' in fill_constant must be list[int] or tuple(int) in Dygraph mode, but "
"received %s, which contains Variable." % type(shape)) "received %s, which contains Variable." % type(shape))
...@@ -644,7 +637,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): ...@@ -644,7 +637,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
"The size of 'shape' in fill_constant can't be zero, " "The size of 'shape' in fill_constant can't be zero, "
"but received %s." % len(shape)) "but received %s." % len(shape))
attrs["shape"] = _get_attr_shape(shape) attrs["shape"] = _get_attr_shape(shape)
if _contain_var(shape): if utils._contain_var(shape):
inputs['ShapeTensorList'] = _get_shape_tensor(shape) inputs['ShapeTensorList'] = _get_shape_tensor(shape)
if out is None: if out is None:
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import collections import collections
import six import six
import numpy as np import numpy as np
from ..framework import Variable
def convert_to_list(value, n, name, dtype=np.int): def convert_to_list(value, n, name, dtype=np.int):
...@@ -244,3 +245,13 @@ def _is_symmetric_padding(padding, data_dim): ...@@ -244,3 +245,13 @@ def _is_symmetric_padding(padding, data_dim):
if padding[i * 2] != padding[i * 2 + 1]: if padding[i * 2] != padding[i * 2 + 1]:
is_sys = False is_sys = False
return is_sys return is_sys
def _contain_var(list_or_tuple):
"""
Check whether list or tuple contains variable.
"""
for item in list_or_tuple:
if isinstance(item, Variable):
return True
return False
...@@ -134,11 +134,11 @@ class Optimizer(object): ...@@ -134,11 +134,11 @@ class Optimizer(object):
# global step if use lr decay # global step if use lr decay
if isinstance(self._learning_rate, LearningRateDecay): if isinstance(self._learning_rate, LearningRateDecay):
var_tmp = None var_tmp = None
if not framework.in_dygraph_mode(): if framework.in_dygraph_mode():
var_temp = Variable(None, name='global_step', dtype='int32')
else:
var_temp = framework._varbase_creator( var_temp = framework._varbase_creator(
None, name='global_step', dtype='int32') None, name='global_step', dtype='int32')
else:
var_temp = Variable(None, name='global_step', dtype='int32')
tensor.fill_constant( tensor.fill_constant(
[1], "int32", self._learning_rate.step_num, out=var_temp) [1], "int32", self._learning_rate.step_num, out=var_temp)
...@@ -546,10 +546,10 @@ class Optimizer(object): ...@@ -546,10 +546,10 @@ class Optimizer(object):
See examples in ``apply_gradients``. See examples in ``apply_gradients``.
""" """
act_no_grad_set = None act_no_grad_set = None
if not framework.in_dygraph_mode(): if framework.in_dygraph_mode():
act_no_grad_set = self._get_no_grad_set(loss, no_grad_set)
else:
pass pass
else:
act_no_grad_set = self._get_no_grad_set(loss, no_grad_set)
self._dtype = loss.dtype self._dtype = loss.dtype
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册