Commit 51cb918a authored by Leo Chen, committed by hong

update layers used in transformer dygraph model, test=develop (#22051)

* update layers, test=develop

* update layers for resnet, test=develop

* fix is_test attr, test=develop

* update cycle_gan, test=develop

* update reinforcement_learning, test=develop

* update ocr, test=develop

* fix bug, test=develop
Parent 5b2e98aa
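
The commit applies one recurring refactor across these layers: collect the op's inputs and attrs into dicts up front, then, when running in dygraph (imperative) mode, call the corresponding core.ops kernel directly and return its output instead of appending an op to the program. A minimal sketch of that pattern, mirroring the leaky_relu change in the diff below (the standalone imports and the name leaky_relu_sketch are illustrative assumptions, not part of the commit):

    from paddle.fluid import core
    from paddle.fluid.framework import in_dygraph_mode
    from paddle.fluid.layer_helper import LayerHelper

    def leaky_relu_sketch(x, alpha=0.02):
        # Inputs/attrs are shared by both execution modes.
        inputs = {'X': [x]}
        attrs = {'alpha': alpha}
        if in_dygraph_mode():
            # Dygraph fast path: run the C++ op immediately,
            # skipping Program/Block bookkeeping.
            outs = core.ops.leaky_relu(inputs, attrs)
            return outs['Out'][0]
        # Static-graph path: record the op as before.
        helper = LayerHelper('leaky_relu', **locals())
        out = helper.create_variable_for_type_inference(dtype=x.dtype)
        helper.append_op(
            type='leaky_relu', inputs=inputs, outputs={'Out': out}, attrs=attrs)
        return out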
@@ -19,7 +19,7 @@ from .. import core
 from ..layers import utils
 from ..dygraph import dygraph_utils
 from . import layers
-from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter
+from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, _varbase_creator
 from ..param_attr import ParamAttr
 from ..initializer import Normal, Constant, NumpyArrayInitializer
 from .. import unique_name
@@ -1134,41 +1134,57 @@ class BatchNorm(layers.Layer):
         # mean and mean_out share the same memory
         mean_out = self._mean
         # variance and variance out share the same memory
         variance_out = self._variance
-        saved_mean = self._helper.create_variable_for_type_inference(
-            dtype=self._dtype, stop_gradient=True)
-        saved_variance = self._helper.create_variable_for_type_inference(
-            dtype=self._dtype, stop_gradient=True)
-        batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference(
-            self._dtype)
+        attrs = {
+            "momentum": self._momentum,
+            "epsilon": self._epsilon,
+            "is_test": self._is_test,
+            "data_layout": self._data_layout,
+            "use_mkldnn": False,
+            "fuse_with_relu": self._fuse_with_relu,
+            "use_global_stats": self._use_global_stats,
+            "trainable_statistics": self._trainable_statistics
+        }
+        inputs = {
+            "X": [input],
+            "Scale": [self.weight],
+            "Bias": [self.bias],
+            "Mean": [self._mean],
+            "Variance": [self._variance]
+        }
+        if in_dygraph_mode():
+            attrs['is_test'] = not _dygraph_tracer()._train_mode
+            saved_mean = _varbase_creator(dtype=self._dtype)
+            saved_variance = _varbase_creator(dtype=self._dtype)
+            batch_norm_out = _varbase_creator(dtype=self._dtype)
+            batch_norm_out.stop_gradient = False
+            # inplace is not supported currently
+        else:
+            saved_mean = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype, stop_gradient=True)
+            saved_variance = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype, stop_gradient=True)
+            batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference(
+                self._dtype)
+        outputs = {
+            "Y": [batch_norm_out],
+            "MeanOut": [mean_out],
+            "VarianceOut": [variance_out],
+            "SavedMean": [saved_mean],
+            "SavedVariance": [saved_variance]
+        }
+        if in_dygraph_mode():
+            outs = core.ops.batch_norm(inputs, attrs, outputs)
+            return dygraph_utils._append_activation_in_dygraph(
+                batch_norm_out, act=self._act)
         self._helper.append_op(
-            type="batch_norm",
-            inputs={
-                "X": input,
-                "Scale": self.weight,
-                "Bias": self.bias,
-                "Mean": self._mean,
-                "Variance": self._variance
-            },
-            outputs={
-                "Y": batch_norm_out,
-                "MeanOut": mean_out,
-                "VarianceOut": variance_out,
-                "SavedMean": saved_mean,
-                "SavedVariance": saved_variance
-            },
-            attrs={
-                "momentum": self._momentum,
-                "epsilon": self._epsilon,
-                "is_test": self._is_test,
-                "data_layout": self._data_layout,
-                "use_mkldnn": False,
-                "fuse_with_relu": self._fuse_with_relu,
-                "use_global_stats": self._use_global_stats,
-                "trainable_statistics": self._trainable_statistics
-            })
+            type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
         # Currently, we don't support inplace in dygraph mode
         return self._helper.append_activation(batch_norm_out, self._act)
@@ -1454,11 +1470,23 @@ class LayerNorm(layers.Layer):
                 ', expected input with shape [*, ' + str_normalized_shape[
                     1:] + ', but got input shape ' + str(input_shape))
         inputs = dict()
-        inputs['X'] = input
+        inputs['X'] = [input]
         if self._scale:
-            inputs['Scale'] = self.weight
+            inputs['Scale'] = [self.weight]
         if self._shift:
-            inputs['Bias'] = self.bias
+            inputs['Bias'] = [self.bias]
+        attrs = {
+            "epsilon": self._epsilon,
+            "begin_norm_axis": self._begin_norm_axis
+        }
+        if in_dygraph_mode():
+            outs = core.ops.layer_norm(inputs, attrs)
+            pre_act = outs['Y'][0]
+            return dygraph_utils._append_activation_in_dygraph(
+                pre_act, act=self._act)
         # create output
         mean_out = self._helper.create_variable_for_type_inference(
             dtype=self._dtype, stop_gradient=True)
@@ -1623,9 +1651,22 @@ class GRUUnit(layers.Layer):
             attr=bias_attr, shape=bias_size, dtype=dtype, is_bias=True)

     def forward(self, input, hidden):
-        inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': self.weight}
+        inputs = {
+            'Input': [input],
+            'HiddenPrev': [hidden],
+            'Weight': [self.weight]
+        }
         if self.bias:
-            inputs['Bias'] = self.bias
+            inputs['Bias'] = [self.bias]
+        attrs = {
+            'activation': self.activation,
+            'gate_activation': self.gate_activation,
+        }
+        if in_dygraph_mode():
+            outs = core.ops.gru_unit(inputs, attrs)
+            return outs['Hidden'][0], outs['ResetHiddenPrev'][0], outs['Gate'][
+                0]
         gate = self._helper.create_variable_for_type_inference(self._dtype)
         reset_hidden_pre = self._helper.create_variable_for_type_inference(
@@ -2277,21 +2318,32 @@ class Conv2DTranspose(layers.Layer):
             is_bias=True)

     def forward(self, input):
+        inputs = {'Input': [input], 'Filter': [self.weight]}
+        attrs = {
+            'output_size': self._output_size,
+            'strides': self._stride,
+            'paddings': self._padding,
+            'dilations': self._dilation,
+            'groups': self._groups,
+            'use_cudnn': self._use_cudnn
+        }
+        if in_dygraph_mode():
+            op = getattr(core.ops, self._op_type)
+            outs = op(inputs, attrs)
+            pre_bias = outs['Output'][0]
+            pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, self.bias,
+                                                            1)
+            return dygraph_utils._append_activation_in_dygraph(
+                pre_act, act=self._act)
         pre_bias = self._helper.create_variable_for_type_inference(
             dtype=input.dtype)
         self._helper.append_op(
             type=self._op_type,
-            inputs={'Input': [input],
-                    'Filter': [self.weight]},
+            inputs=inputs,
             outputs={'Output': pre_bias},
-            attrs={
-                'output_size': self._output_size,
-                'strides': self._stride,
-                'paddings': self._padding,
-                'dilations': self._dilation,
-                'groups': self._groups,
-                'use_cudnn': self._use_cudnn
-            })
+            attrs=attrs)

         if self.bias is not None:
             pre_act = self._helper.create_variable_for_type_inference(
......
@@ -1118,6 +1118,13 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
                               fetch_list=[result[0]])
             print(output)
    """
+    inputs = {"X": [input]}
+    attrs = {"axis": axis, "use_cudnn": use_cudnn}
+    if in_dygraph_mode():
+        outs = core.ops.softmax(inputs, attrs)
+        return outs['Out'][0]
+
    helper = LayerHelper('softmax', **locals())
    check_type_and_dtype(input, 'input', Variable,
                         ['float16', 'float32', 'float64'], 'softmax')
@@ -1128,8 +1135,7 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
        type="softmax",
        inputs={"X": input},
        outputs={"Out": softmax_out},
-        attrs={"axis": axis,
-               "use_cudnn": use_cudnn})
+        attrs=attrs)

    return softmax_out
@@ -5398,22 +5404,24 @@ def one_hot(input, depth, allow_out_of_range=False):
            label = fluid.data(name="label", shape=[4, 1], dtype="int64")
            one_hot_label = fluid.layers.one_hot(input=label, depth=4)
    """
-    helper = LayerHelper("one_hot", **locals())
+    if in_dygraph_mode():
+        inputs = {'X': [input]}
+        attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
+        outs = core.ops.one_hot(inputs, attrs)
+        outs['Out'][0].stop_gradient = True
+        return outs['Out'][0]
+
+    helper = LayerHelper("one_hot", **locals())
    one_hot_out = helper.create_variable_for_type_inference(dtype='float32')
-    if in_dygraph_mode():
-        inputs = {'X': input}
-        attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
-    else:
-        if not isinstance(depth, Variable):
-            # user attribute
-            inputs = {'X': input}
-            attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
-        else:
-            depth.stop_gradient = True
-            inputs = {'X': input, 'depth_tensor': depth}
-            attrs = {'allow_out_of_range': allow_out_of_range}
+    if not isinstance(depth, Variable):
+        # user attribute
+        inputs = {'X': input}
+        attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
+    else:
+        depth.stop_gradient = True
+        inputs = {'X': input, 'depth_tensor': depth}
+        attrs = {'allow_out_of_range': allow_out_of_range}
    helper.append_op(
        type="one_hot",
        inputs=inputs,
@@ -6266,6 +6274,15 @@ def label_smooth(label,
    """
    if epsilon > 1. or epsilon < 0.:
        raise ValueError("The value of epsilon must be between 0 and 1.")
+
+    if in_dygraph_mode():
+        inputs = {"X": [label]}
+        if prior_dist:
+            inputs["PriorDist"] = [prior_dist]
+        attrs = {"epsilon": float(epsilon)}
+        outs = core.ops.label_smooth(inputs, attrs)
+        return outs['Out'][0]
+
    helper = LayerHelper("label_smooth", **locals())
    label.stop_gradient = True
    smooth_label = helper.create_variable_for_type_inference(dtype)
@@ -7839,6 +7856,11 @@ def log(x, name=None):
            res_val, = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res])
            print(res_val) # [[0.], [0.6931472]]
    """
+    inputs = {'X': [x]}
+    if in_dygraph_mode():
+        outs = core.ops.log(inputs)
+        return outs['Out'][0]
+
    helper = LayerHelper('log', **locals())
    dtype = helper.input_dtype(input_param_name='x')
    out = helper.create_variable_for_type_inference(dtype)
@@ -7874,6 +7896,11 @@ def relu(x, name=None):
                # [[0.  0. ]
                #  [1.  2.6]]
    """
+    inputs = {'X': [x]}
+    if in_dygraph_mode():
+        outs = core.ops.relu(inputs)
+        return outs['Out'][0]
+
    helper = LayerHelper('relu', **locals())
    dtype = helper.input_dtype(input_param_name='x')
    out = helper.create_variable_for_type_inference(dtype)
@@ -8462,6 +8489,17 @@ def pad2d(input,
            result = fluid.layers.pad2d(input=data, paddings=[1, 2, 3, 4],
                                        mode='reflect')
    """
+    attrs = {'mode': mode, 'pad_value': pad_value, 'data_format': data_format}
+    inputs = {'X': [input]}
+    if isinstance(paddings, Variable):
+        inputs['Paddings'] = [paddings]
+        attrs['paddings'] = []
+    else:
+        attrs['paddings'] = paddings
+
+    if in_dygraph_mode():
+        outs = core.ops.pad2d(inputs, attrs)
+        return outs['Out'][0]
+
    helper = LayerHelper('pad2d', **locals())
@@ -8470,14 +8508,6 @@ def pad2d(input,
    dtype = helper.input_dtype(input_param_name='input')
    out = helper.create_variable_for_type_inference(dtype)
-    inputs = {'X': input}
-    attrs = {'mode': mode, 'pad_value': pad_value, 'data_format': data_format}
-    if isinstance(paddings, Variable):
-        inputs['Paddings'] = paddings
-        attrs['paddings'] = []
-    else:
-        attrs['paddings'] = paddings
-
    helper.append_op(
        type='pad2d', inputs=inputs, outputs={"Out": out}, attrs=attrs)
@@ -8907,13 +8937,16 @@ def leaky_relu(x, alpha=0.02, name=None):
            res_val, = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res])
            print(res_val) # [[-0.1, 2], [3, -0.4]]
    """
+    inputs = {'X': [x]}
+    attrs = {'alpha': alpha}
+    if in_dygraph_mode():
+        outs = core.ops.leaky_relu(inputs, attrs)
+        return outs['Out'][0]
+
    helper = LayerHelper('leaky_relu', **locals())
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
-        type='leaky_relu',
-        inputs={'X': x},
-        outputs={'Out': out},
-        attrs={'alpha': alpha})
+        type='leaky_relu', inputs=inputs, outputs={'Out': out}, attrs=attrs)
    return out
@@ -9311,6 +9344,32 @@ def expand(x, expand_times, name=None):
            expanded_2 = fluid.layers.expand(data_2, expand_times=expand_times)
            # the shape of expanded_2 is [48, 56].
    """
+
+    def contain_var(expand_times):
+        for ele in expand_times:
+            if isinstance(ele, Variable):
+                return True
+        return False
+
+    inputs = {"X": [x]}
+    attrs = {}
+
+    if in_dygraph_mode():
+        if isinstance(expand_times, (list, tuple)):
+            contain_var = contain_var(expand_times)
+            if contain_var:
+                raise TypeError(
+                    "The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but "
+                    "received %s, which contains Variable." % type(shape))
+            attrs['expand_times'] = expand_times
+        else:
+            raise TypeError(
+                "The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but "
+                "received %s." % type(shape))
+
+        outs = core.ops.expand(inputs, attrs)
+        return outs['Out'][0]
+
    check_type_and_dtype(x, 'x', Variable,
                         ['bool', 'float32', 'float64', 'int32', 'int64'],
                         'expand')
@@ -9320,14 +9379,6 @@ def expand(x, expand_times, name=None):
            "expand op bool date type must set the stop_gradient to be False")
    helper = LayerHelper('expand', input=x, **locals())
-    inputs = {"X": x}
-    attrs = {}
-
-    def contain_var(expand_times):
-        for ele in expand_times:
-            if isinstance(ele, Variable):
-                return True
-        return False

    def get_attr_expand_times(list_expand_times):
        attrs_expand_times = []
@@ -10363,24 +10414,27 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
            print(res) # [array([[ 3.,  5.,  7.], [ 9., 11., 13.]], dtype=float32)]
    """
-    helper = LayerHelper('scale', **locals())
-    if name is None:
-        out = helper.create_variable_for_type_inference(dtype=x.dtype)
-    else:
-        out = helper.create_variable(
-            name=name, dtype=x.dtype, persistable=False)
-
-    inputs = {'X': x}
+    inputs = {'X': [x]}
    attrs = {
        'bias': float(bias),
        'bias_after_scale': bias_after_scale,
    }
    if isinstance(scale, Variable):
-        inputs['ScaleTensor'] = scale
+        inputs['ScaleTensor'] = [scale]
    else:
        attrs['scale'] = float(scale)
+
+    if in_dygraph_mode():
+        outs = core.ops.scale(inputs, attrs)
+        return dygraph_utils._append_activation_in_dygraph(outs['Out'][0])
+
+    helper = LayerHelper('scale', **locals())
+    if name is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    else:
+        out = helper.create_variable(
+            name=name, dtype=x.dtype, persistable=False)
    helper.append_op(
        type='scale', inputs=inputs, outputs={'Out': out}, attrs=attrs)
    return helper.append_activation(out)
@@ -10817,6 +10871,9 @@ Examples:
        print(z_value)#[[[[0., 0., 0., 0., 0.] .... [0., 0., 0., 0., 0.]]]]
    """
+    if in_dygraph_mode():
+        return _elementwise_op_in_dygraph(
+            x, y, axis=axis, act=act, op_name='elementwise_min')
    return _elementwise_op(LayerHelper('elementwise_min', **locals()))
@@ -11407,6 +11464,11 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
    """
+    inputs = {"X": [x], "Y": [y]}
+    attrs = {"x_num_col_dims": x_num_col_dims, "y_num_col_dims": y_num_col_dims}
+    if in_dygraph_mode():
+        outs = core.ops.mul(inputs, attrs)
+        return outs['Out'][0]
+
    helper = LayerHelper("mul", **locals())
    check_type_and_dtype(x, 'x', Variable, ['float16', 'float32', 'float64'],
@@ -11420,14 +11482,8 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
            name=name, dtype=x.dtype, persistable=False)

    helper.append_op(
-        type="mul",
-        inputs={"X": x,
-                "Y": y},
-        attrs={
-            "x_num_col_dims": x_num_col_dims,
-            "y_num_col_dims": y_num_col_dims
-        },
-        outputs={"Out": out})
+        type="mul", inputs={"X": x,
+                            "Y": y}, attrs=attrs, outputs={"Out": out})
    return out
......
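
For context, a small usage sketch (not part of the diff) of how one of the updated functional layers exercises the new fast path when running under the dygraph guard; the input values and printed result are illustrative:

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(
            np.array([[-1.0, 2.0], [3.0, -4.0]], dtype='float32'))
        # Dispatches through core.ops.leaky_relu instead of building a program.
        y = fluid.layers.leaky_relu(x, alpha=0.1)
        print(y.numpy())  # [[-0.1  2. ]
                          #  [ 3.  -0.4]]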
@@ -16,7 +16,7 @@ from __future__ import print_function
 from six.moves import reduce
 from ..layer_helper import LayerHelper
 from ..param_attr import ParamAttr
-from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode
+from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator
 from ..framework import Variable
 from ..initializer import Constant, force_init_on_cpu
 from ..core import VarDesc
@@ -552,6 +552,43 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
          shape = fluid.layers.fill_constant([1,2], "int32", 2) # shape=[2,2]
          data4 = fluid.layers.fill_constant(shape=shape, dtype='bool', value=True) # data4=[[True,True],[True,True]]
    """
+
+    def _contain_var(one_list):
+        for ele in one_list:
+            if isinstance(ele, Variable):
+                return True
+        return False
+
+    attrs = {
+        'value': float(value),
+        'force_cpu': force_cpu or force_init_on_cpu()
+    }
+    if convert_dtype(dtype) in ['int64', 'int32']:
+        attrs['str_value'] = str(int(value))
+    else:
+        attrs['str_value'] = str(float(value))
+
+    if in_dygraph_mode():
+        if isinstance(shape, (list, tuple)):
+            contain_var = _contain_var(shape)
+            if contain_var:
+                raise TypeError(
+                    "The type of 'shape' in fill_constant must be list[int] or tuple(int) in Dygraph mode, but "
+                    "received %s, which contains Variable." % type(shape))
+            attrs['shape'] = shape
+        else:
+            raise TypeError(
+                "The type of 'shape' in fill_constant must be list[int] or tuple(int) in Dygraph mode, but "
+                "received %s." % type(shape))
+
+        if out is None:
+            out = _varbase_creator(dtype=dtype)
+        attrs['dtype'] = out.dtype
+        outputs = {'Out': [out]}
+        outs = core.ops.fill_constant({}, attrs, outputs)
+        out.stop_gradient = True
+        return out
+
    helper = LayerHelper("fill_constant", **locals())
    check_dtype(dtype, 'create data type',
                ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
@@ -568,12 +605,6 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
    else:
        attrs['str_value'] = str(float(value))

-    def _contain_var(one_list):
-        for ele in one_list:
-            if isinstance(ele, Variable):
-                return True
-        return False
-
    def _get_attr_shape(list_shape):
        attr_shape = []
        for idx, dim in enumerate(list_shape):
......
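
A brief usage note on the fill_constant change (illustrative, not part of the diff): in dygraph mode the shape argument must now be a plain list or tuple of ints, and the call eagerly returns a filled variable; a Variable-valued shape raises TypeError.

    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        data = fluid.layers.fill_constant(shape=[2, 3], dtype='int64', value=5)
        print(data.numpy())  # [[5 5 5]
                             #  [5 5 5]]
        # Passing a Variable (or a list containing one) as `shape` raises TypeError here.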
@@ -896,21 +896,30 @@ class MomentumOptimizer(Optimizer):
        velocity_acc = self._get_accumulator(self._velocity_acc_str,
                                             param_and_grad[0])
+        attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov}
+        inputs = {
+            "Param": [param_and_grad[0]],
+            "Grad": [param_and_grad[1]],
+            "Velocity": [velocity_acc],
+            "LearningRate": [self._create_param_lr(param_and_grad)]
+        }
+
+        outputs = {
+            "ParamOut": [param_and_grad[0]],
+            "VelocityOut": [velocity_acc]
+        }
+
+        if framework.in_dygraph_mode():
+            core.ops.momentum(inputs, attrs, outputs)
+            return None
+
        # create the momentum optimize op
        momentum_op = block.append_op(
            type=self.type,
-            inputs={
-                "Param": param_and_grad[0],
-                "Grad": param_and_grad[1],
-                "Velocity": velocity_acc,
-                "LearningRate": self._create_param_lr(param_and_grad)
-            },
-            outputs={
-                "ParamOut": param_and_grad[0],
-                "VelocityOut": velocity_acc
-            },
-            attrs={"mu": self._momentum,
-                   "use_nesterov": self._use_nesterov},
+            inputs=inputs,
+            outputs=outputs,
+            attrs=attrs,
            stop_gradient=True)

        return momentum_op
......
@@ -15,6 +15,7 @@
 from __future__ import print_function

 from . import framework
+from .framework import in_dygraph_mode, _varbase_creator
 from . import core

 __all__ = ['L1Decay', 'L2Decay', 'L1DecayRegularizer', 'L2DecayRegularizer']
@@ -74,10 +75,12 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
                lod_level=param.lod_level,
                type=core.VarDesc.VarType.LOD_TENSOR)

-            grad.block.append_op(
-                type='sum',
-                inputs={"X": [grad, regularization_term]},
-                outputs={"Out": new_grad})
+            inputs = {"X": [grad, regularization_term]}
+            outputs = {"Out": [new_grad]}
+            if in_dygraph_mode():
+                core.ops.sum(inputs, {}, outputs)
+            else:
+                grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)

            params_and_grads.append((param, new_grad))
@@ -165,20 +168,24 @@ class L2DecayRegularizer(WeightDecayRegularizer):
        assert isinstance(param, framework.Parameter)
        assert isinstance(block, framework.Block)

+        inputs = {"X": [param]}
+        attrs = {"scale": self._regularization_coeff}
+
        if framework.in_dygraph_mode():
-            decay = block.create_var(dtype=param.dtype, shape=param.shape)
+            outs = core.ops.scale(inputs, attrs)
+            return outs['Out'][0]
        else:
            decay = block.create_var(
                dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)

        # Append Op to calculate decay
        block.append_op(
            type='scale',
            inputs={"X": param},
            outputs={"Out": decay},
            attrs={"scale": self._regularization_coeff})

        return decay

    def __str__(self):
        return "L2Decay, regularization_coeff=%f" % self._regularization_coeff
......
@@ -27,6 +27,8 @@ from test_imperative_base import new_program_scope
 from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
 from paddle.fluid.dygraph import TracedLayer

+#NOTE(zhiqiu): run with FLAGS_cudnn_deterministic=1
+
 batch_size = 8
 train_parameters = {
     "input_size": [3, 224, 224],
......