提交 5ab56871 编写于 作者: Z Zhen Wang

remove no necessary doc changes. test=develop

上级 6b854f3e
......@@ -627,6 +627,183 @@ class Variable(object):
"""
self.error_clip = error_clip
def _slice_indices(self, slice, length):
"""
Reference implementation for the slice.indices method.
"""
# Compute step and length as integers.
step = 1 if slice.step is None else slice.step
# Raise ValueError for negative length or zero step.
if length < 0:
raise ValueError("length should not be negative")
if step == 0:
raise ValueError("slice step cannot be zero")
# Find lower and upper bounds for start and stop.
lower = -1 if step < 0 else 0
upper = length - 1 if step < 0 else length
# Compute start.
if slice.start is None:
start = upper if step < 0 else lower
else:
start = slice.start
start = max(start + length, lower) if start < 0 else min(start,
upper)
# Compute stop.
if slice.stop is None:
stop = lower if step < 0 else upper
else:
stop = slice.stop
stop = max(stop + length, lower) if stop < 0 else min(stop, upper)
return start, stop, step
def _detectEllipsis(self, item):
has_ellipsis = False
start = 0
end = len(self.shape)
for index, o in enumerate(item):
if o is Ellipsis:
if has_ellipsis:
raise ValueError("Index can have one ellipsis only.")
has_ellipsis = True
start = index
else:
if has_ellipsis:
end = index
return has_ellipsis, start, end
def _reconstructSliceinfo(self, item):
has_ellipsis, start, end = self._detectEllipsis(item)
if has_ellipsis:
newitem = []
for i in range(start):
newitem.append(item[i])
for i in range(start, end):
newitem.append(slice(None, None, None))
for i in range(end, len(item)):
newitem.append(item[i])
return newitem
else:
return None
def _detectContinuesSlice(self, item):
starts = []
ends = []
for index, o in enumerate(item):
if isinstance(o, int):
start = int(o)
if (index > 0 and index >= self.shape[index]) \
or (index < 0 and (index + self.shape[index]) < 0):
raise IndexError("invalid index")
start = max(start + self.shape[index], 0) if start < 0 else min(
start, self.shape[index])
starts.append(start)
ends.append(start + 1)
elif isinstance(o, slice):
start, stop, step = self._slice_indices(o, self.shape[index])
if step == 1 or step == -1:
starts.append(start)
ends.append(stop)
else:
return False, None
else:
raise IndexError("Valid index accept int or slice or ellipsis")
return True, [starts, ends]
def _cloneVar(self, copy=False):
if not copy:
return self.block.create_var(
name=unique_name.generate(".".join(self.name)),
dtype=self.dtype,
persistable=self.persistable,
stop_gradient=self._stop_gradient, )
else:
return self
def _sliceVar(self, axes, starts, ends):
new_var = self._cloneVar()
self.block.append_op(
type="slice",
inputs={'Input': [self]},
outputs={'Out': [new_var]},
attrs={'axes': axes,
'starts': starts,
'ends': ends})
return new_var
def _concatVar(self, inputs, axis):
new_var = self._cloneVar()
self.block.append_op(
type="concat",
inputs={'X': inputs},
outputs={'Out': [new_var]},
attrs={'axis': axis, })
return new_var
def _sliceAndConcatVar(self, item, axis):
if isinstance(item, slice):
if self.shape[axis] < 0:
return self._cloneVar(True)
start, stop, step = self._slice_indices(item, self.shape[axis])
if step == 1:
return self._sliceVar([axis], [start], [stop])
else:
vars = []
if step > 0:
while start < stop:
vars.append(
self._sliceVar([axis], [start], [start + 1]))
start += step
else:
while start > stop:
vars.append(
self._sliceVar([axis], [start], [start + 1]))
start += step
return self._concatVar(vars, axis)
elif isinstance(item, int):
if self.shape[axis] < 0:
return self._cloneVar(True)
index = int(item)
if (index > 0 and index >= self.shape[axis])\
or (index < 0 and (index + self.shape[axis]) < 0):
raise IndexError("invalid index")
return self._sliceVar([axis], [index], [index + 1])
else:
raise IndexError("Valid index accept int or slice or tuple")
def __getitem__(self, item):
"""
Slice the variable.
Args:
item(int/slice/tuple) : the index.
Returns:
Sliced variable
"""
new_var = None
if isinstance(item, tuple):
if len(item) > len(self.shape):
raise IndexError("Too many indexes")
newitem = self._reconstructSliceinfo(item) or item
check, info = self._detectContinuesSlice(newitem)
if check:
starts = info[0]
ends = info[1]
axes = [i for i in range(len(starts))]
return self._sliceVar(axes, starts, ends)
else:
new_var = self
for index, o in enumerate(newitem):
new_var = new_var._sliceAndConcatVar(o, index)
else:
new_var = self._sliceAndConcatVar(item, 0)
return new_var
def get_all_op_protos():
"""
......@@ -744,7 +921,7 @@ class Operator(object):
if _in_imperative_mode():
if type is None:
raise ValueError(
"`type` to initilized an Operator can not be None.")
"`type` to initialized an Operator can not be None.")
self.iop = core.OpBase(type)
# TODO(minqiyang): remove these lines after we take apart all
......@@ -906,6 +1083,9 @@ class Operator(object):
@property
def type(self):
if _in_imperative_mode():
return self.iop.type
else:
return self.desc.type()
def input(self, name):
......@@ -1022,6 +1202,9 @@ class Operator(object):
"""
self._update_desc_attr(name, val)
def _remove_attr(self, name):
self.desc.remove_attr(name)
def _update_desc_attr(self, name, val):
"""
Update the value of desc's attribute by attribute's name.
......@@ -2515,6 +2698,10 @@ class Program(object):
self._trainers_endpoints = []
# the distributed lookup table names
self._distributed_lookup_table = None
# use Deep gradient comrepssion or not
self._enable_dgc = False
# @deprecated(the python memory optimize transpiler is deprecated)
# whether the program is optimized by memory_optimize_transpiler
self.__is_mem_optimized = False
......@@ -2565,6 +2752,15 @@ class Program(object):
def set_op_role_var(self, var_name):
self._op_role_var = [var_name]
@contextlib.contextmanager
def _backward_role_guard(self):
tmp_role = self._current_role
OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.Backward
yield
self._current_role = tmp_role
@signature_safe_contextmanager
def _optimized_guard(self, param_and_grads):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册