From 51cb918a05f5d2d630a25d2bd1587714f6ee48cc Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Thu, 9 Jan 2020 11:20:59 +0800 Subject: [PATCH] update layers used in transformer dygraph model, test=develop (#22051) * update layers, test=develop * update layers for resnet, test=develop * fix is_test attr, test=develop * update cycle_gan, test=develop * update reinforcement_learning, test=develop * update ocr, test=develop * fix bug, test=develop --- python/paddle/fluid/dygraph/nn.py | 146 ++++++++++------ python/paddle/fluid/layers/nn.py | 156 ++++++++++++------ python/paddle/fluid/layers/tensor.py | 45 ++++- python/paddle/fluid/optimizer.py | 33 ++-- python/paddle/fluid/regularizer.py | 31 ++-- .../tests/unittests/test_imperative_resnet.py | 2 + 6 files changed, 285 insertions(+), 128 deletions(-) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 5c28617b6e7..dc288d94661 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -19,7 +19,7 @@ from .. import core from ..layers import utils from ..dygraph import dygraph_utils from . import layers -from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter +from ..framework import Variable, in_dygraph_mode, OpProtoHolder, Parameter, _dygraph_tracer, _varbase_creator from ..param_attr import ParamAttr from ..initializer import Normal, Constant, NumpyArrayInitializer from .. import unique_name @@ -1134,41 +1134,57 @@ class BatchNorm(layers.Layer): # mean and mean_out share the same memory mean_out = self._mean # variance and variance out share the same memory + variance_out = self._variance + attrs = { + "momentum": self._momentum, + "epsilon": self._epsilon, + "is_test": self._is_test, + "data_layout": self._data_layout, + "use_mkldnn": False, + "fuse_with_relu": self._fuse_with_relu, + "use_global_stats": self._use_global_stats, + "trainable_statistics": self._trainable_statistics + } - saved_mean = self._helper.create_variable_for_type_inference( - dtype=self._dtype, stop_gradient=True) - saved_variance = self._helper.create_variable_for_type_inference( - dtype=self._dtype, stop_gradient=True) - batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference( - self._dtype) + inputs = { + "X": [input], + "Scale": [self.weight], + "Bias": [self.bias], + "Mean": [self._mean], + "Variance": [self._variance] + } + + if in_dygraph_mode(): + attrs['is_test'] = not _dygraph_tracer()._train_mode + saved_mean = _varbase_creator(dtype=self._dtype) + saved_variance = _varbase_creator(dtype=self._dtype) + batch_norm_out = _varbase_creator(dtype=self._dtype) + batch_norm_out.stop_gradient = False + # inplace is not supported currently + else: + saved_mean = self._helper.create_variable_for_type_inference( + dtype=self._dtype, stop_gradient=True) + saved_variance = self._helper.create_variable_for_type_inference( + dtype=self._dtype, stop_gradient=True) + batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference( + self._dtype) + + outputs = { + "Y": [batch_norm_out], + "MeanOut": [mean_out], + "VarianceOut": [variance_out], + "SavedMean": [saved_mean], + "SavedVariance": [saved_variance] + } + + if in_dygraph_mode(): + outs = core.ops.batch_norm(inputs, attrs, outputs) + return dygraph_utils._append_activation_in_dygraph( + batch_norm_out, act=self._act) self._helper.append_op( - type="batch_norm", - inputs={ - "X": input, - "Scale": self.weight, - "Bias": self.bias, - "Mean": self._mean, - "Variance": 
self._variance - }, - outputs={ - "Y": batch_norm_out, - "MeanOut": mean_out, - "VarianceOut": variance_out, - "SavedMean": saved_mean, - "SavedVariance": saved_variance - }, - attrs={ - "momentum": self._momentum, - "epsilon": self._epsilon, - "is_test": self._is_test, - "data_layout": self._data_layout, - "use_mkldnn": False, - "fuse_with_relu": self._fuse_with_relu, - "use_global_stats": self._use_global_stats, - "trainable_statistics": self._trainable_statistics - }) + type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs) # Currently, we don't support inplace in dygraph mode return self._helper.append_activation(batch_norm_out, self._act) @@ -1454,11 +1470,23 @@ class LayerNorm(layers.Layer): ', expected input with shape [*, ' + str_normalized_shape[ 1:] + ', but got input shape ' + str(input_shape)) inputs = dict() - inputs['X'] = input + inputs['X'] = [input] if self._scale: - inputs['Scale'] = self.weight + inputs['Scale'] = [self.weight] if self._shift: - inputs['Bias'] = self.bias + inputs['Bias'] = [self.bias] + + attrs = { + "epsilon": self._epsilon, + "begin_norm_axis": self._begin_norm_axis + } + + if in_dygraph_mode(): + outs = core.ops.layer_norm(inputs, attrs) + pre_act = outs['Y'][0] + return dygraph_utils._append_activation_in_dygraph( + pre_act, act=self._act) + # create output mean_out = self._helper.create_variable_for_type_inference( dtype=self._dtype, stop_gradient=True) @@ -1623,9 +1651,22 @@ class GRUUnit(layers.Layer): attr=bias_attr, shape=bias_size, dtype=dtype, is_bias=True) def forward(self, input, hidden): - inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': self.weight} + inputs = { + 'Input': [input], + 'HiddenPrev': [hidden], + 'Weight': [self.weight] + } if self.bias: - inputs['Bias'] = self.bias + inputs['Bias'] = [self.bias] + attrs = { + 'activation': self.activation, + 'gate_activation': self.gate_activation, + } + + if in_dygraph_mode(): + outs = core.ops.gru_unit(inputs, attrs) + return outs['Hidden'][0], outs['ResetHiddenPrev'][0], outs['Gate'][ + 0] gate = self._helper.create_variable_for_type_inference(self._dtype) reset_hidden_pre = self._helper.create_variable_for_type_inference( @@ -2277,21 +2318,32 @@ class Conv2DTranspose(layers.Layer): is_bias=True) def forward(self, input): + inputs = {'Input': [input], 'Filter': [self.weight]} + attrs = { + 'output_size': self._output_size, + 'strides': self._stride, + 'paddings': self._padding, + 'dilations': self._dilation, + 'groups': self._groups, + 'use_cudnn': self._use_cudnn + } + + if in_dygraph_mode(): + op = getattr(core.ops, self._op_type) + outs = op(inputs, attrs) + pre_bias = outs['Output'][0] + pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, self.bias, + 1) + return dygraph_utils._append_activation_in_dygraph( + pre_act, act=self._act) + pre_bias = self._helper.create_variable_for_type_inference( dtype=input.dtype) self._helper.append_op( type=self._op_type, - inputs={'Input': [input], - 'Filter': [self.weight]}, + inputs=inputs, outputs={'Output': pre_bias}, - attrs={ - 'output_size': self._output_size, - 'strides': self._stride, - 'paddings': self._padding, - 'dilations': self._dilation, - 'groups': self._groups, - 'use_cudnn': self._use_cudnn - }) + attrs=attrs) if self.bias is not None: pre_act = self._helper.create_variable_for_type_inference( diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a536c0b8394..e9494566990 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1118,6 +1118,13 @@ 
def softmax(input, use_cudnn=False, name=None, axis=-1): fetch_list=[result[0]]) print(output) """ + inputs = {"X": [input]} + attrs = {"axis": axis, "use_cudnn": use_cudnn} + + if in_dygraph_mode(): + outs = core.ops.softmax(inputs, attrs) + return outs['Out'][0] + helper = LayerHelper('softmax', **locals()) check_type_and_dtype(input, 'input', Variable, ['float16', 'float32', 'float64'], 'softmax') @@ -1128,8 +1135,7 @@ def softmax(input, use_cudnn=False, name=None, axis=-1): type="softmax", inputs={"X": input}, outputs={"Out": softmax_out}, - attrs={"axis": axis, - "use_cudnn": use_cudnn}) + attrs=attrs) return softmax_out @@ -5398,22 +5404,24 @@ def one_hot(input, depth, allow_out_of_range=False): label = fluid.data(name="label", shape=[4, 1], dtype="int64") one_hot_label = fluid.layers.one_hot(input=label, depth=4) """ - helper = LayerHelper("one_hot", **locals()) + if in_dygraph_mode(): + inputs = {'X': [input]} + attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range} + outs = core.ops.one_hot(inputs, attrs) + outs['Out'][0].stop_gradient = True + return outs['Out'][0] + helper = LayerHelper("one_hot", **locals()) one_hot_out = helper.create_variable_for_type_inference(dtype='float32') - if in_dygraph_mode(): + if not isinstance(depth, Variable): + # user attribute inputs = {'X': input} attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range} else: - if not isinstance(depth, Variable): - # user attribute - inputs = {'X': input} - attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range} - else: - depth.stop_gradient = True - inputs = {'X': input, 'depth_tensor': depth} - attrs = {'allow_out_of_range': allow_out_of_range} + depth.stop_gradient = True + inputs = {'X': input, 'depth_tensor': depth} + attrs = {'allow_out_of_range': allow_out_of_range} helper.append_op( type="one_hot", inputs=inputs, @@ -6266,6 +6274,15 @@ def label_smooth(label, """ if epsilon > 1. or epsilon < 0.: raise ValueError("The value of epsilon must be between 0 and 1.") + + if in_dygraph_mode(): + inputs = {"X": [label]} + if prior_dist: + inputs["PriorDist"] = [prior_dist] + attrs = {"epsilon": float(epsilon)} + outs = core.ops.label_smooth(inputs, attrs) + return outs['Out'][0] + helper = LayerHelper("label_smooth", **locals()) label.stop_gradient = True smooth_label = helper.create_variable_for_type_inference(dtype) @@ -7839,6 +7856,11 @@ def log(x, name=None): res_val, = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res]) print(res_val) # [[0.], [0.6931472]] """ + inputs = {'X': [x]} + if in_dygraph_mode(): + outs = core.ops.log(inputs) + return outs['Out'][0] + helper = LayerHelper('log', **locals()) dtype = helper.input_dtype(input_param_name='x') out = helper.create_variable_for_type_inference(dtype) @@ -7874,6 +7896,11 @@ def relu(x, name=None): # [[0. 0. ] # [1. 
2.6]] """ + inputs = {'X': [x]} + if in_dygraph_mode(): + outs = core.ops.relu(inputs) + return outs['Out'][0] + helper = LayerHelper('relu', **locals()) dtype = helper.input_dtype(input_param_name='x') out = helper.create_variable_for_type_inference(dtype) @@ -8462,6 +8489,17 @@ def pad2d(input, result = fluid.layers.pad2d(input=data, paddings=[1, 2, 3, 4], mode='reflect') """ + attrs = {'mode': mode, 'pad_value': pad_value, 'data_format': data_format} + inputs = {'X': [input]} + if isinstance(paddings, Variable): + inputs['Paddings'] = [paddings] + attrs['paddings'] = [] + else: + attrs['paddings'] = paddings + + if in_dygraph_mode(): + outs = core.ops.pad2d(inputs, attrs) + return outs['Out'][0] helper = LayerHelper('pad2d', **locals()) @@ -8470,14 +8508,6 @@ def pad2d(input, dtype = helper.input_dtype(input_param_name='input') out = helper.create_variable_for_type_inference(dtype) - inputs = {'X': input} - attrs = {'mode': mode, 'pad_value': pad_value, 'data_format': data_format} - - if isinstance(paddings, Variable): - inputs['Paddings'] = paddings - attrs['paddings'] = [] - else: - attrs['paddings'] = paddings helper.append_op( type='pad2d', inputs=inputs, outputs={"Out": out}, attrs=attrs) @@ -8907,13 +8937,16 @@ def leaky_relu(x, alpha=0.02, name=None): res_val, = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res]) print(res_val) # [[-0.1, 2], [3, -0.4]] """ + inputs = {'X': [x]} + attrs = {'alpha': alpha} + if in_dygraph_mode(): + outs = core.ops.leaky_relu(inputs, attrs) + return outs['Out'][0] + helper = LayerHelper('leaky_relu', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( - type='leaky_relu', - inputs={'X': x}, - outputs={'Out': out}, - attrs={'alpha': alpha}) + type='leaky_relu', inputs=inputs, outputs={'Out': out}, attrs=attrs) return out @@ -9311,6 +9344,32 @@ def expand(x, expand_times, name=None): expanded_2 = fluid.layers.expand(data_2, expand_times=expand_times) # the shape of expanded_2 is [48, 56]. """ + + def contain_var(expand_times): + for ele in expand_times: + if isinstance(ele, Variable): + return True + return False + + inputs = {"X": [x]} + attrs = {} + + if in_dygraph_mode(): + if isinstance(expand_times, (list, tuple)): + contain_var = contain_var(expand_times) + if contain_var: + raise TypeError( + "The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but " + "received %s, which contains Variable." % type(shape)) + attrs['expand_times'] = expand_times + else: + raise TypeError( + "The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but " + "received %s." 
% type(shape)) + + outs = core.ops.expand(inputs, attrs) + return outs['Out'][0] + check_type_and_dtype(x, 'x', Variable, ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand') @@ -9320,14 +9379,6 @@ def expand(x, expand_times, name=None): "expand op bool date type must set the stop_gradient to be False") helper = LayerHelper('expand', input=x, **locals()) - inputs = {"X": x} - attrs = {} - - def contain_var(expand_times): - for ele in expand_times: - if isinstance(ele, Variable): - return True - return False def get_attr_expand_times(list_expand_times): attrs_expand_times = [] @@ -10363,24 +10414,27 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): print(res) # [array([[ 3., 5., 7.], [ 9., 11., 13.]], dtype=float32)] """ - - helper = LayerHelper('scale', **locals()) - if name is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) - else: - out = helper.create_variable( - name=name, dtype=x.dtype, persistable=False) - - inputs = {'X': x} + inputs = {'X': [x]} attrs = { 'bias': float(bias), 'bias_after_scale': bias_after_scale, } if isinstance(scale, Variable): - inputs['ScaleTensor'] = scale + inputs['ScaleTensor'] = [scale] else: attrs['scale'] = float(scale) + if in_dygraph_mode(): + outs = core.ops.scale(inputs, attrs) + return dygraph_utils._append_activation_in_dygraph(outs['Out'][0]) + + helper = LayerHelper('scale', **locals()) + if name is None: + out = helper.create_variable_for_type_inference(dtype=x.dtype) + else: + out = helper.create_variable( + name=name, dtype=x.dtype, persistable=False) + helper.append_op( type='scale', inputs=inputs, outputs={'Out': out}, attrs=attrs) return helper.append_activation(out) @@ -10817,6 +10871,9 @@ Examples: print(z_value)#[[[[0., 0., 0., 0., 0.] .... 
[0., 0., 0., 0., 0.]]]] """ + if in_dygraph_mode(): + return _elementwise_op_in_dygraph( + x, y, axis=axis, act=act, op_name='elementwise_min') return _elementwise_op(LayerHelper('elementwise_min', **locals())) @@ -11407,6 +11464,11 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): """ + inputs = {"X": [x], "Y": [y]} + attrs = {"x_num_col_dims": x_num_col_dims, "y_num_col_dims": y_num_col_dims} + if in_dygraph_mode(): + outs = core.ops.mul(inputs, attrs) + return outs['Out'][0] helper = LayerHelper("mul", **locals()) check_type_and_dtype(x, 'x', Variable, ['float16', 'float32', 'float64'], @@ -11420,14 +11482,8 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): name=name, dtype=x.dtype, persistable=False) helper.append_op( - type="mul", - inputs={"X": x, - "Y": y}, - attrs={ - "x_num_col_dims": x_num_col_dims, - "y_num_col_dims": y_num_col_dims - }, - outputs={"Out": out}) + type="mul", inputs={"X": x, + "Y": y}, attrs=attrs, outputs={"Out": out}) return out diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 02ade74173f..675c9183824 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -16,7 +16,7 @@ from __future__ import print_function from six.moves import reduce from ..layer_helper import LayerHelper from ..param_attr import ParamAttr -from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode +from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator from ..framework import Variable from ..initializer import Constant, force_init_on_cpu from ..core import VarDesc @@ -552,6 +552,43 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): shape = fluid.layers.fill_constant([1,2], "int32", 2) # shape=[2,2] data4 = fluid.layers.fill_constant(shape=shape, dtype='bool', value=True) # data4=[[True,True],[True,True]] """ + + def _contain_var(one_list): + for ele in one_list: + if isinstance(ele, Variable): + return True + return False + + attrs = { + 'value': float(value), + 'force_cpu': force_cpu or force_init_on_cpu() + } + + if convert_dtype(dtype) in ['int64', 'int32']: + attrs['str_value'] = str(int(value)) + else: + attrs['str_value'] = str(float(value)) + + if in_dygraph_mode(): + if isinstance(shape, (list, tuple)): + contain_var = _contain_var(shape) + if contain_var: + raise TypeError( + "The type of 'shape' in fill_constant must be list[int] or tuple(int) in Dygraph mode, but " + "received %s, which contains Variable." % type(shape)) + attrs['shape'] = shape + else: + raise TypeError( + "The type of 'shape' in fill_constant must be list[int] or tuple(int) in Dygraph mode, but " + "received %s." 
% type(shape)) + if out is None: + out = _varbase_creator(dtype=dtype) + attrs['dtype'] = out.dtype + outputs = {'Out': [out]} + outs = core.ops.fill_constant({}, attrs, outputs) + out.stop_gradient = True + return out + helper = LayerHelper("fill_constant", **locals()) check_dtype(dtype, 'create data type', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], @@ -568,12 +605,6 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): else: attrs['str_value'] = str(float(value)) - def _contain_var(one_list): - for ele in one_list: - if isinstance(ele, Variable): - return True - return False - def _get_attr_shape(list_shape): attr_shape = [] for idx, dim in enumerate(list_shape): diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index c05475045d9..caa665759a2 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -896,21 +896,30 @@ class MomentumOptimizer(Optimizer): velocity_acc = self._get_accumulator(self._velocity_acc_str, param_and_grad[0]) + attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov} + + inputs = { + "Param": [param_and_grad[0]], + "Grad": [param_and_grad[1]], + "Velocity": [velocity_acc], + "LearningRate": [self._create_param_lr(param_and_grad)] + } + + outputs = { + "ParamOut": [param_and_grad[0]], + "VelocityOut": [velocity_acc] + } + + if framework.in_dygraph_mode(): + core.ops.momentum(inputs, attrs, outputs) + return None + # create the momentum optimize op momentum_op = block.append_op( type=self.type, - inputs={ - "Param": param_and_grad[0], - "Grad": param_and_grad[1], - "Velocity": velocity_acc, - "LearningRate": self._create_param_lr(param_and_grad) - }, - outputs={ - "ParamOut": param_and_grad[0], - "VelocityOut": velocity_acc - }, - attrs={"mu": self._momentum, - "use_nesterov": self._use_nesterov}, + inputs=inputs, + outputs=outputs, + attrs=attrs, stop_gradient=True) return momentum_op diff --git a/python/paddle/fluid/regularizer.py b/python/paddle/fluid/regularizer.py index 6c93f5c5060..d6774faf686 100644 --- a/python/paddle/fluid/regularizer.py +++ b/python/paddle/fluid/regularizer.py @@ -15,6 +15,7 @@ from __future__ import print_function from . import framework +from .framework import in_dygraph_mode, _varbase_creator from . 
import core __all__ = ['L1Decay', 'L2Decay', 'L1DecayRegularizer', 'L2DecayRegularizer'] @@ -74,10 +75,12 @@ def append_regularization_ops(parameters_and_grads, regularization=None): lod_level=param.lod_level, type=core.VarDesc.VarType.LOD_TENSOR) - grad.block.append_op( - type='sum', - inputs={"X": [grad, regularization_term]}, - outputs={"Out": new_grad}) + inputs = {"X": [grad, regularization_term]} + outputs = {"Out": [new_grad]} + if in_dygraph_mode(): + core.ops.sum(inputs, {}, outputs) + else: + grad.block.append_op(type='sum', inputs=inputs, outputs=outputs) params_and_grads.append((param, new_grad)) @@ -165,20 +168,24 @@ class L2DecayRegularizer(WeightDecayRegularizer): assert isinstance(param, framework.Parameter) assert isinstance(block, framework.Block) + inputs = {"X": [param]} + attrs = {"scale": self._regularization_coeff} + if framework.in_dygraph_mode(): - decay = block.create_var(dtype=param.dtype, shape=param.shape) + outs = core.ops.scale(inputs, attrs) + return outs['Out'][0] else: decay = block.create_var( dtype=param.dtype, shape=param.shape, lod_level=param.lod_level) - # Append Op to calculate decay - block.append_op( - type='scale', - inputs={"X": param}, - outputs={"Out": decay}, - attrs={"scale": self._regularization_coeff}) + # Append Op to calculate decay + block.append_op( + type='scale', + inputs={"X": param}, + outputs={"Out": decay}, + attrs={"scale": self._regularization_coeff}) - return decay + return decay def __str__(self): return "L2Decay, regularization_coeff=%f" % self._regularization_coeff diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py index c123e0254d6..9f609355b13 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py @@ -27,6 +27,8 @@ from test_imperative_base import new_program_scope from utils import DyGraphProgramDescTracerTestHelper, is_equal_program from paddle.fluid.dygraph import TracedLayer +#NOTE(zhiqiu): run with FLAGS_cudnn_deterministic=1 + batch_size = 8 train_parameters = { "input_size": [3, 224, 224], -- GitLab
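
For reference, a minimal dygraph usage sketch for two of the layers this patch touches (relu and softmax). This snippet is not part of the patch; the input shape and values are illustrative, and it assumes a Paddle build that already includes this change, so both calls take the new core.ops fast path under dygraph mode instead of appending ops through LayerHelper:

import numpy as np
import paddle.fluid as fluid

# Toy input; any float32 array works here.
x_np = np.random.random((2, 3)).astype('float32')

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(x_np)
    # With this patch applied, these calls dispatch directly to
    # core.ops.relu / core.ops.softmax when in_dygraph_mode() is true.
    y = fluid.layers.relu(x)
    z = fluid.layers.softmax(x, axis=-1)
    print(y.numpy())
    print(z.numpy())

The user-facing API is unchanged; the refactor only swaps the dygraph execution path inside each layer, which is why the existing unit tests (e.g. test_imperative_resnet.py) are reused with FLAGS_cudnn_deterministic=1 rather than rewritten.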