Commit 01f4f2d7 authored by lujun

merge conflict, test=develop

......@@ -13,6 +13,7 @@ paddle.fluid.name_scope (ArgSpec(args=['prefix'], varargs=None, keywords=None, d
paddle.fluid.cuda_places (ArgSpec(args=['device_ids'], varargs=None, keywords=None, defaults=(None,)), ('document', '7d9a51fc9cf3c5245b5227080a8064c3'))
paddle.fluid.cpu_places (ArgSpec(args=['device_count'], varargs=None, keywords=None, defaults=(None,)), ('document', '4c0cd83f0b401fc2ff84c70974e5d210'))
paddle.fluid.cuda_pinned_places (ArgSpec(args=['device_count'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd0c3ebd813c39958c92b78e3eef7e912'))
+paddle.fluid.in_dygraph_mode (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'f06314a1cb30c96b5808dde2219c2dae'))
paddle.fluid.Executor.__init__ (ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.Executor.close (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'f5369953dd0c443961cf79f7a00e1a03'))
paddle.fluid.Executor.infer_from_dataset (ArgSpec(args=['self', 'program', 'dataset', 'scope', 'thread', 'debug', 'fetch_list', 'fetch_info', 'print_period'], varargs=None, keywords=None, defaults=(None, None, None, 0, False, None, None, 100)), ('document', '9c7decb955b9c4f718114179c8985581'))
......
......@@ -22,7 +22,7 @@ __all__ = ['enabled', 'guard', 'to_variable']
def enabled():
-return framework._in_dygraph_mode()
+return framework.in_dygraph_mode()
@signature_safe_contextmanager
......
......@@ -16,7 +16,7 @@ from __future__ import print_function
import copy
import six
-from ..framework import Parameter, _in_dygraph_mode
+from ..framework import Parameter, in_dygraph_mode
from ..param_attr import ParamAttr
from .. import core
from six.moves import zip
......
......@@ -19,7 +19,7 @@ from six.moves import reduce
from .. import core
from ..layers import utils
from . import layers
-from ..framework import Variable, _in_dygraph_mode, OpProtoHolder, Parameter
+from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter
from ..param_attr import ParamAttr
from ..initializer import Normal, Constant, NumpyArrayInitializer
import numpy as np
......@@ -2128,7 +2128,7 @@ class SequenceConv(layers.Layer):
bias_attr=None,
param_attr=None,
act=None):
-assert not _in_dygraph_mode(
+assert not in_dygraph_mode(
), "SequenceConv is not supported by dynamic graph mode yet!"
super(SequenceConv, self).__init__(name_scope)
self._num_filters = num_filters
......@@ -2168,7 +2168,7 @@ class RowConv(layers.Layer):
future_context_size,
param_attr=None,
act=None):
-assert not _in_dygraph_mode(
+assert not in_dygraph_mode(
), "RowConv is not supported by dynamic graph mode yet!"
super(RowConv, self).__init__(name_scope)
self._act = act
......
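For context (not part of the diff): the two asserts above make these static-graph-only layers fail fast when constructed under a dygraph guard. A minimal sketch of what that looks like, assuming the constructor signature shown in the hunk (name_scope first, then num_filters):

```python
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import SequenceConv

with fluid.dygraph.guard():
    # SequenceConv is static-graph only in this release, so the assert
    # added above fires before any parameters are created.
    try:
        conv = SequenceConv("seq_conv", num_filters=2)
    except AssertionError as e:
        print(e)  # "SequenceConv is not supported by dynamic graph mode yet!"
```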
......@@ -67,6 +67,7 @@ __all__ = [
'cuda_places',
'cpu_places',
'cuda_pinned_places',
+'in_dygraph_mode',
]
EMPTY_VAR_NAME = core.kEmptyVarName()
......@@ -79,7 +80,10 @@ _dygraph_tracer_ = None
_dygraph_current_expected_place_ = None
-def _in_dygraph_mode():
+def in_dygraph_mode():
+    '''
+    Returns(bool): True if the program is running in dynamic graph mode
+    '''
return _dygraph_tracer_ is not None
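For context (not part of the diff), a minimal usage sketch of the renamed, now-public helper, assuming the fluid 1.x dygraph API that the rest of this change set exercises:

```python
import numpy as np
import paddle.fluid as fluid

# Outside a dygraph guard no tracer is installed, so the check is False.
print(fluid.in_dygraph_mode())  # False

with fluid.dygraph.guard():
    # Inside the guard _dygraph_tracer_ is set and the check is True,
    # so imperative helpers such as to_variable become usable.
    print(fluid.in_dygraph_mode())  # True
    x = fluid.dygraph.to_variable(np.ones([2, 2], dtype='float32'))
    print(x.numpy())
```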
......@@ -396,7 +400,7 @@ class Variable(object):
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
-if _in_dygraph_mode():
+if in_dygraph_mode():
# record vars in tracer rather than blocks
self._ivar = kwargs.get("ivar", None)
if not self._ivar:
......@@ -516,7 +520,7 @@ class Variable(object):
Returns:
str: The debug string.
"""
-if _in_dygraph_mode():
+if in_dygraph_mode():
# TODO(panyx0718): add more dygraph debug info.
return 'name %s, dtype: %s shape: %s' % (self.name, self.dtype,
self.shape)
......@@ -549,42 +553,42 @@ class Variable(object):
@property
def stop_gradient(self):
-if _in_dygraph_mode():
+if in_dygraph_mode():
return self._ivar.stop_gradient
else:
return self._stop_gradient
@stop_gradient.setter
def stop_gradient(self, s):
-if _in_dygraph_mode():
+if in_dygraph_mode():
self._ivar.stop_gradient = s
else:
self._stop_gradient = s
@property
def persistable(self):
-if _in_dygraph_mode():
+if in_dygraph_mode():
return self._ivar.persistable
else:
return self.desc.persistable()
@persistable.setter
def persistable(self, p):
-if _in_dygraph_mode():
+if in_dygraph_mode():
return self._ivar.persistable
else:
self.desc.set_persistable(p)
@property
def name(self):
-if _in_dygraph_mode():
+if in_dygraph_mode():
return self._ivar.name
else:
return cpt.to_text(self.desc.name())
@name.setter
def name(self, new_name):
-if _in_dygraph_mode():
+if in_dygraph_mode():
self._ivar.name = new_name
else:
self.desc.set_name(new_name)
......@@ -592,14 +596,14 @@ class Variable(object):
@property
def shape(self):
# convert to tuple, make it as same as numpy API.
-if _in_dygraph_mode():
+if in_dygraph_mode():
return self._ivar.shape
else:
return tuple(self.desc.shape())
@property
def dtype(self):
-if _in_dygraph_mode():
+if in_dygraph_mode():
return self._ivar.dtype
else:
return self.desc.dtype()
......@@ -611,7 +615,7 @@ class Variable(object):
@property
def type(self):
-if _in_dygraph_mode():
+if in_dygraph_mode():
return self._ivar.dtype
else:
return self.desc.type()
......@@ -930,7 +934,7 @@ class Operator(object):
inputs=None,
outputs=None,
attrs=None):
-if _in_dygraph_mode():
+if in_dygraph_mode():
if type is None:
raise ValueError(
"`type` to initialized an Operator can not be None.")
......@@ -1049,7 +1053,7 @@ class Operator(object):
for arg in out_args:
out_arg_names.append(cpt.to_text(arg.name))
# TODO(minqiyang): could we remove variable's op in static mode?
-if not _in_dygraph_mode():
+if not in_dygraph_mode():
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
......@@ -1095,7 +1099,7 @@ class Operator(object):
@property
def type(self):
-if _in_dygraph_mode():
+if in_dygraph_mode():
return self.iop.type
else:
return self.desc.type()
......@@ -1638,7 +1642,7 @@ class Block(object):
Returns:
Operator: the append Operator.
"""
-if _in_dygraph_mode():
+if in_dygraph_mode():
op = Operator(
block=self,
desc=None,
......@@ -1710,7 +1714,7 @@ class Block(object):
return self.ops[start:end]
def _prepend_op(self, *args, **kwargs):
-if _in_dygraph_mode():
+if in_dygraph_mode():
op = Operator(
self,
None,
......
......@@ -165,7 +165,7 @@ class ConstantInitializer(Initializer):
'force_cpu': self._force_cpu or force_init_on_cpu()
},
stop_gradient=True)
-if not framework._in_dygraph_mode():
+if not framework.in_dygraph_mode():
var.op = op
return op
......@@ -245,7 +245,7 @@ class UniformInitializer(Initializer):
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
-if not framework._in_dygraph_mode():
+if not framework.in_dygraph_mode():
var.op = op
return op
......@@ -324,7 +324,7 @@ class NormalInitializer(Initializer):
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
-if not framework._in_dygraph_mode():
+if not framework.in_dygraph_mode():
var.op = op
return op
......@@ -403,7 +403,7 @@ class TruncatedNormalInitializer(Initializer):
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
-if not framework._in_dygraph_mode():
+if not framework.in_dygraph_mode():
var.op = op
return op
......@@ -509,7 +509,7 @@ class XavierInitializer(Initializer):
"seed": self._seed
},
stop_gradient=True)
-if not framework._in_dygraph_mode():
+if not framework.in_dygraph_mode():
var.op = op
return op
......@@ -610,7 +610,7 @@ class MSRAInitializer(Initializer):
"seed": self._seed
},
stop_gradient=True)
-if not framework._in_dygraph_mode():
+if not framework.in_dygraph_mode():
var.op = op
return op
......@@ -709,7 +709,7 @@ class BilinearInitializer(Initializer):
'shape': list(shape),
value_name: values
})
-if not framework._in_dygraph_mode():
+if not framework.in_dygraph_mode():
var.op = op
return op
......@@ -768,7 +768,7 @@ class NumpyArrayInitializer(Initializer):
value_name: values
},
stop_gradient=True)
-if not framework._in_dygraph_mode():
+if not framework.in_dygraph_mode():
var.op = op
return op
......
......@@ -17,7 +17,7 @@ from __future__ import print_function
import copy
import six
-from .framework import Parameter, dtype_is_floating, _in_dygraph_mode
+from .framework import Parameter, dtype_is_floating, in_dygraph_mode
from . import unique_name
from paddle.fluid.initializer import Constant, Xavier
from .param_attr import ParamAttr
......
......@@ -17,7 +17,7 @@ from __future__ import print_function
import copy
import numpy as np
-from .framework import Variable, default_main_program, default_startup_program, _in_dygraph_mode, _current_expected_place
+from .framework import Variable, default_main_program, default_startup_program, in_dygraph_mode, _current_expected_place
from . import unique_name
from .param_attr import ParamAttr, WeightNormParamAttr
from . import core
......@@ -54,7 +54,7 @@ class LayerHelperBase(object):
Return Variable construct from value
"""
if isinstance(value, np.ndarray):
-assert _in_dygraph_mode(
+assert in_dygraph_mode(
), "to_variable could only be called in dygraph mode"
if not block:
......@@ -302,7 +302,7 @@ class LayerHelperBase(object):
param = self._create_weight_normalize(attr, shape, dtype)
WeightNormParamAttr.params_with_weight_norm.append(param)
return param
-if _in_dygraph_mode():
+if in_dygraph_mode():
# In dygraph mode, we want the returned parameter to be
# initialized so that it can be used imperatively.
return self.main_program.global_block().create_parameter(
......@@ -370,7 +370,7 @@ class LayerHelperBase(object):
initializer: initializer to use
"""
assert isinstance(var, Variable)
-if _in_dygraph_mode():
+if in_dygraph_mode():
initializer(var, var.block)
else:
self.startup_program.global_block().create_var(
......
......@@ -23,7 +23,7 @@ import os
import inspect
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant, NumpyArrayInitializer
-from ..framework import Variable, OpProtoHolder, _in_dygraph_mode
+from ..framework import Variable, OpProtoHolder, in_dygraph_mode
from ..dygraph import base
from ..param_attr import ParamAttr
from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
......@@ -3288,7 +3288,7 @@ def layer_norm(input,
>>> dtype='float32')
>>> x = fluid.layers.layer_norm(input=data, begin_norm_axis=1)
"""
-assert _in_dygraph_mode(
+assert in_dygraph_mode(
) is not True, "please use FC instead of fc in dygraph mode!"
helper = LayerHelper('layer_norm', **locals())
dtype = helper.input_dtype()
......@@ -6454,7 +6454,7 @@ def squeeze(input, axes, name=None):
x = layers.data(name='x', shape=[5, 1, 10])
y = layers.sequeeze(input=x, axes=[1])
"""
-assert not _in_dygraph_mode(), (
+assert not in_dygraph_mode(), (
"squeeze layer is not supported in dygraph mode yet.")
helper = LayerHelper("squeeze", **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
......@@ -9193,7 +9193,7 @@ def _elementwise_op(helper):
op_type = helper.layer_type
x = helper.kwargs.get('x', None)
y = helper.kwargs.get('y', None)
-if _in_dygraph_mode():
+if in_dygraph_mode():
x = base.to_variable(x)
y = base.to_variable(y)
......
......@@ -55,7 +55,7 @@ class Optimizer(object):
"""
def __init__(self, learning_rate, regularization=None, name=None):
-if framework._in_dygraph_mode():
+if framework.in_dygraph_mode():
if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, LearningRateDecay):
raise TypeError(
......@@ -205,7 +205,7 @@ class Optimizer(object):
name = self._name + "_" + name
if (name in self._accumulators and
param.name in self._accumulators[name]):
-if framework._in_dygraph_mode():
+if framework.in_dygraph_mode():
return self._accumulators[name][param.name]
raise Exception("Accumulator {} already exists for parameter {}".
format(name, param.name))
......@@ -363,7 +363,7 @@ class Optimizer(object):
See examples in `apply_gradients`.
"""
self._dtype = loss.dtype
-if framework._in_dygraph_mode():
+if framework.in_dygraph_mode():
if parameter_list is not None:
parameters = parameter_list
else:
......@@ -448,7 +448,7 @@ class Optimizer(object):
Returns:
list: A list of operators appended to the current program.
"""
-if framework._in_dygraph_mode():
+if framework.in_dygraph_mode():
with program_guard(framework.default_main_program(),
framework.default_startup_program()):
optimize_ops = self._create_optimization_pass(params_grads)
......
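A small sketch of the dygraph-mode constructor check shown above (hypothetical values; not part of the diff): in dygraph mode the learning rate must be a plain float or a LearningRateDecay instance, otherwise the TypeError in that branch is raised at construction time.

```python
import paddle.fluid as fluid
from paddle.fluid.optimizer import SGDOptimizer

with fluid.dygraph.guard():
    sgd = SGDOptimizer(learning_rate=1e-3)  # ok: plain float

    try:
        # Anything else (e.g. a string here) trips the TypeError
        # raised in the dygraph-mode branch above.
        SGDOptimizer(learning_rate="0.001")
    except TypeError as e:
        print(e)
```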
......@@ -16,7 +16,8 @@ from __future__ import print_function
import unittest
import paddle.fluid as fluid
-from paddle.fluid import Embedding
+import paddle.fluid.core as core
+from paddle.fluid.dygraph.nn import Embedding
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.base import to_variable
......@@ -43,7 +44,7 @@ class SimpleLSTMRNN(fluid.Layer):
self.cell_array = []
self.hidden_array = []
-def _build_once(self, input_embedding, init_hidden=None, init_cell=None):
+def build_once(self, input_embedding, init_hidden=None, init_cell=None):
self.weight_1_arr = []
self.weight_2_arr = []
self.bias_arr = []
......@@ -175,7 +176,7 @@ class PtbModel(fluid.Layer):
default_initializer=fluid.initializer.UniformInitializer(
low=-self.init_scale, high=self.init_scale))
-def _build_once(self, input, label, init_hidden, init_cell):
+def build_once(self, input, label, init_hidden, init_cell):
pass
def forward(self, input, label, init_hidden, init_cell):
......@@ -277,7 +278,8 @@ class TestDygraphPtbRnn(unittest.TestCase):
num_steps=num_steps,
init_scale=init_scale)
-exe = fluid.Executor(fluid.CPUPlace())
+exe = fluid.Executor(fluid.CPUPlace(
+) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
sgd = SGDOptimizer(learning_rate=1e-3)
x = fluid.layers.data(
name="x", shape=[-1, num_steps, 1], dtype='int64')
......@@ -331,18 +333,15 @@ class TestDygraphPtbRnn(unittest.TestCase):
static_param_updated[static_param_name_list[k -
3]] = out[k]
-self.assertTrue(np.allclose(static_loss_value, dy_loss.numpy()))
-self.assertTrue(np.allclose(static_last_cell_value, last_cell.numpy()))
+self.assertTrue(np.array_equal(static_loss_value, dy_loss.numpy()))
self.assertTrue(
-np.allclose(static_last_hidden_value, last_hidden.numpy()))
+np.array_equal(static_last_cell_value, last_cell.numpy()))
+self.assertTrue(
+np.array_equal(static_last_hidden_value, last_hidden.numpy()))
for key, value in six.iteritems(static_param_init):
# print("static_init name: {}, value {}".format(key, value))
# print("dy_init name: {}, value {}".format(key, dy_param_init[key]))
self.assertTrue(np.allclose(value, dy_param_init[key]))
self.assertTrue(np.array_equal(value, dy_param_init[key]))
for key, value in six.iteritems(static_param_updated):
# print("static name: {}, value {}".format(key, value))
# print("dy name: {}, value {}".format(key, dy_param_updated[key]))
self.assertTrue(np.allclose(value, dy_param_updated[key]))
self.assertTrue(np.array_equal(value, dy_param_updated[key]))
if __name__ == '__main__':
......
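Side note on the test change above (not part of the diff): np.array_equal demands exact element-wise equality, whereas np.allclose tolerates small floating-point differences, so the static-versus-dygraph comparisons are now strict. A quick illustration of the difference:

```python
import numpy as np

a = np.array([1.0])
b = a + 1e-7  # tiny perturbation

print(np.allclose(a, b))     # True: within the default rtol/atol tolerance
print(np.array_equal(a, b))  # False: values are not exactly identical
```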