diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 40ed5869c276522dde65cb7028e553f0443e5d62..23509773fa9e0697159f0365cc21ba84fb0ab1bf 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -91,7 +91,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_ten set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost) if (WITH_GPU) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor) endif() diff --git a/paddle/fluid/operators/run_program_op.h b/paddle/fluid/operators/run_program_op.h index 505ce4c09681d3405227b0e2e8b8b1209a3d359f..ae09e87473d23d16a5f73b42b4f5a4e8e641c0bc 100644 --- a/paddle/fluid/operators/run_program_op.h +++ b/paddle/fluid/operators/run_program_op.h @@ -17,10 +17,12 @@ limitations under the License. */ #include #include #include +#include <unordered_set> #include #include #include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" @@ -149,14 +151,46 @@ static void ShareVarsFromScope(const std::vector &vars, } } -static void AppendSkipDeletionVars( - std::vector<std::string> *all_vars, - const std::vector<std::string> &append_vars) { +static void AppendSkipDeletionVars(const std::vector<std::string> &append_vars, + std::vector<std::string> *all_vars) { for (auto &var : append_vars) { all_vars->emplace_back(var); } } +static void AppendSafeEagerDeletionSkipVars( + const framework::ProgramDesc &program, + std::vector<std::string> *skip_vars) { + const framework::BlockDesc &block = program.Block(0); + const std::vector<framework::OpDesc *> &all_ops = block.AllOps(); + + std::unordered_set<std::string> grad_op_output; + std::unordered_set<std::string> grad_op_input; + for (const framework::OpDesc *op : all_ops) { + int op_role = BOOST_GET_CONST( + int, op->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName())); + if ((op_role & static_cast<int>(framework::OpRole::kBackward)) == 0) { + continue; + } + + for (const std::string &in_arg_name : op->InputArgumentNames()) { + grad_op_input.emplace(in_arg_name); + } + for (const std::string &out_arg_name : op->OutputArgumentNames()) { + grad_op_output.emplace(out_arg_name); + } + } + + // For each grad op input variable, if it is not an output of any grad op, + // it may be an output of a forward op, so we add it to skip_vars here to + // prevent it from being deleted when the grad ops run multiple times.
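+  // Illustrative example (hypothetical variable name): if a forward op writes
+  // a temporary var tmp_0 that a backward op only reads, tmp_0 appears in
+  // grad_op_input but not in grad_op_output, so the loop below keeps it in
+  // skip_vars and it survives eager deletion across repeated grad runs.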
+ for (const std::string &var_name : grad_op_input) { + if (grad_op_output.find(var_name) == grad_op_output.end()) { + skip_vars->emplace_back(var_name); + } + } +} + } // namespace details template <typename DeviceContext, typename T> @@ -192,7 +226,7 @@ class RunProgramOpKernel : public framework::OpKernel<T> { // skip delete vars std::vector<std::string> skip_vars; - details::AppendSkipDeletionVars(&skip_vars, output_var_names); + details::AppendSkipDeletionVars(output_var_names, &skip_vars); VLOG(2) << "Prepare to skip " << skip_vars.size() << " var(s): " << string::join_strings(skip_vars, ' '); @@ -261,20 +295,21 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> { out_scope_vec->size(), 1, platform::errors::InvalidArgument( "The OutScope of RunProgramGradOp should only hold one scope.")); + auto &scope = *(out_scope_vec->front()); // Step 2. prepare executor and scope framework::Executor exe(ctx.GetPlace()); // skip delete vars std::vector<std::string> skip_vars; - details::AppendSkipDeletionVars(&skip_vars, input_grad_var_names); - details::AppendSkipDeletionVars(&skip_vars, param_grad_names); + details::AppendSkipDeletionVars(input_grad_var_names, &skip_vars); + details::AppendSkipDeletionVars(param_grad_names, &skip_vars); + details::AppendSafeEagerDeletionSkipVars(*program, &skip_vars); VLOG(2) << "Prepare to skip " << skip_vars.size() << " var(s): " << string::join_strings(skip_vars, ' '); auto exe_ctx = exe.Prepare(*program, 0, skip_vars); - auto &scope = *(out_scope_vec->front()); details::ShareVarsIntoScope(output_grad_vars, output_grad_var_names, &scope); diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/simnet_dygraph_model.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/simnet_dygraph_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2520b3722882d48f0ed021f82b271b51a7ffcd16 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/simnet_dygraph_model.py @@ -0,0 +1,516 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import paddle.fluid.param_attr as attr + +from functools import reduce +from paddle.fluid.dygraph import declarative, to_variable +from paddle.fluid.dygraph import Embedding, Layer, Linear + + +class EmbeddingLayer(object): + """ + Embedding Layer class + """ + + def __init__(self, dict_size, emb_dim, name="emb", padding_idx=None): + """ + initialize + """ + self.dict_size = dict_size + self.emb_dim = emb_dim + self.name = name + self.padding_idx = padding_idx + + def ops(self): + """ + operation + """ + # TODO(huihuangzheng): The original code set is_sparse=True, but it + # causes a crash in dy2stat. Set it back to True after fixing it.
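+        # Note: is_sparse=False means the embedding weight receives a dense
+        # gradient instead of a sparse (SelectedRows) one.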
+ emb = Embedding( + size=[self.dict_size, self.emb_dim], + is_sparse=False, + padding_idx=self.padding_idx, + param_attr=attr.ParamAttr( + name=self.name, initializer=fluid.initializer.Xavier())) + + return emb + + +class FCLayer(object): + """ + Fully Connect Layer class + """ + + def __init__(self, fc_dim, act, name="fc"): + """ + initialize + """ + self.fc_dim = fc_dim + self.act = act + self.name = name + + def ops(self): + """ + operation + """ + fc = FC(size=self.fc_dim, + param_attr=attr.ParamAttr(name="%s.w" % self.name), + bias_attr=attr.ParamAttr(name="%s.b" % self.name), + act=self.act) + return fc + + +class ConcatLayer(object): + """ + Connection Layer class + """ + + def __init__(self, axis): + """ + initialize + """ + self.axis = axis + + def ops(self, inputs): + """ + operation + """ + concat = fluid.layers.concat(inputs, axis=self.axis) + return concat + + +class ReduceMeanLayer(object): + """ + Reduce Mean Layer class + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, input): + """ + operation + """ + mean = fluid.layers.reduce_mean(input) + return mean + + +class CosSimLayer(object): + """ + Cos Similarly Calculate Layer + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, x, y): + """ + operation + """ + sim = fluid.layers.cos_sim(x, y) + return sim + + +class ElementwiseMaxLayer(object): + """ + Elementwise Max Layer class + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, x, y): + """ + operation + """ + max = fluid.layers.elementwise_max(x, y) + return max + + +class ElementwiseAddLayer(object): + """ + Elementwise Add Layer class + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, x, y): + """ + operation + """ + add = fluid.layers.elementwise_add(x, y) + return add + + +class ElementwiseSubLayer(object): + """ + Elementwise Add Layer class + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, x, y): + """ + operation + """ + sub = fluid.layers.elementwise_sub(x, y) + return sub + + +class ConstantLayer(object): + """ + Generate A Constant Layer class + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, input, shape, dtype, value): + """ + operation + """ + shape = list(shape) + input_shape = fluid.layers.shape(input) + shape[0] = input_shape[0] + constant = fluid.layers.fill_constant(shape, dtype, value) + return constant + + +class SoftsignLayer(object): + """ + Softsign Layer class + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, input): + """ + operation + """ + softsign = fluid.layers.softsign(input) + return softsign + + +class FC(Layer): + """ + This interface is used to construct a callable object of the ``FC`` class. + For more details, refer to code examples. + It creates a fully connected layer in the network. It can take + one or multiple ``Tensor`` as its inputs. It creates a Variable called weights for each input tensor, + which represents a fully connected weight matrix from each input unit to + each output unit. The fully connected layer multiplies each input tensor + with its corresponding weight to produce an output Tensor with shape [N, `size`], + where N is batch size. If multiple input tensors are given, the results of + multiple output tensors with shape [N, `size`] will be summed up. If ``bias_attr`` + is not None, a bias variable will be created and added to the output. + Finally, if ``act`` is not None, it will be applied to the output as well. 
+ When the input is single ``Tensor`` : + .. math:: + Out = Act({XW + b}) + When the input are multiple ``Tensor`` : + .. math:: + Out = Act({\sum_{i=0}^{N-1}X_iW_i + b}) + In the above equation: + * :math:`N`: Number of the input. N equals to len(input) if input is list of ``Tensor`` . + * :math:`X_i`: The i-th input ``Tensor`` . + * :math:`W_i`: The i-th weights matrix corresponding i-th input tensor. + * :math:`b`: The bias parameter created by this layer (if needed). + * :math:`Act`: The activation function. + * :math:`Out`: The output ``Tensor`` . + See below for an example. + .. code-block:: text + Given: + data_1.data = [[[0.1, 0.2]]] + data_1.shape = (1, 1, 2) # 1 is batch_size + data_2.data = [[[0.1, 0.2, 0.3]]] + data_2.shape = (1, 1, 3) # 1 is batch_size + fc = FC("fc", 2, num_flatten_dims=2) + out = fc(input=[data_1, data_2]) + Then: + out.data = [[[0.182996 -0.474117]]] + out.shape = (1, 1, 2) + Parameters: + + size(int): The number of output units in this layer. + num_flatten_dims (int, optional): The fc layer can accept an input tensor with more than + two dimensions. If this happens, the multi-dimension tensor will first be flattened + into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input + tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1) + dimensions will be flatten to form the first dimension of the final matrix (height of + the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to + form the second dimension of the final matrix (width of the matrix). For example, suppose + `X` is a 5-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3. + Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. Default: 1 + param_attr (ParamAttr or list of ParamAttr, optional): The parameter attribute for learnable + weights(Parameter) of this layer. Default: None. + bias_attr (ParamAttr or list of ParamAttr, optional): The attribute for the bias + of this layer. If it is set to False, no bias will be added to the output units. + If it is set to None, the bias is initialized zero. Default: None. + act (str, optional): Activation to be applied to the output of this layer. Default: None. + is_test(bool, optional): A flag indicating whether execution is in test phase. Default: False. + dtype(str, optional): Dtype used for weight, it can be "float32" or "float64". Default: "float32". + Attribute: + **weight** (list of Parameter): the learnable weights of this layer. + **bias** (Parameter or None): the learnable bias of this layer. + Returns: + None + + Examples: + .. 
code-block:: python + from paddle.fluid.dygraph.base import to_variable + import paddle.fluid as fluid + from paddle.fluid.dygraph import FC + import numpy as np + data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32') + with fluid.dygraph.guard(): + fc = FC("fc", 64, num_flatten_dims=2) + data = to_variable(data) + conv = fc(data) + """ + + def __init__(self, + size, + num_flatten_dims=1, + param_attr=None, + bias_attr=None, + act=None, + is_test=False, + dtype="float32"): + super(FC, self).__init__(dtype) + + self._size = size + self._num_flatten_dims = num_flatten_dims + self._dtype = dtype + self._param_attr = param_attr + self._bias_attr = bias_attr + self._act = act + self.__w = list() + + def _build_once(self, input): + i = 0 + for inp, param in self._helper.iter_inputs_and_params(input, + self._param_attr): + input_shape = inp.shape + + param_shape = [ + reduce(lambda a, b: a * b, input_shape[self._num_flatten_dims:], + 1) + ] + [self._size] + self.__w.append( + self.add_parameter( + '_w%d' % i, + self.create_parameter( + attr=param, + shape=param_shape, + dtype=self._dtype, + is_bias=False))) + i += 1 + + size = list([self._size]) + self._b = self.create_parameter( + attr=self._bias_attr, shape=size, dtype=self._dtype, is_bias=True) + + # TODO(songyouwei): We should remove _w property + @property + def _w(self, i=0): + return self.__w[i] + + @_w.setter + def _w(self, value, i=0): + assert isinstance(self.__w[i], Variable) + self.__w[i].set_value(value) + + @property + def weight(self): + if len(self.__w) > 1: + return self.__w + else: + return self.__w[0] + + @weight.setter + def weight(self, value): + if len(self.__w) == 1: + self.__w[0] = value + + @property + def bias(self): + return self._b + + @bias.setter + def bias(self, value): + self._b = value + + def forward(self, input): + mul_results = list() + i = 0 + for inp, param in self._helper.iter_inputs_and_params(input, + self._param_attr): + tmp = self._helper.create_variable_for_type_inference(self._dtype) + self._helper.append_op( + type="mul", + inputs={"X": inp, + "Y": self.__w[i]}, + outputs={"Out": tmp}, + attrs={ + "x_num_col_dims": self._num_flatten_dims, + "y_num_col_dims": 1 + }) + i += 1 + mul_results.append(tmp) + + if len(mul_results) == 1: + pre_bias = mul_results[0] + else: + pre_bias = self._helper.create_variable_for_type_inference( + self._dtype) + self._helper.append_op( + type="sum", + inputs={"X": mul_results}, + outputs={"Out": pre_bias}, + attrs={"use_mkldnn": False}) + + if self._b is not None: + pre_activation = self._helper.create_variable_for_type_inference( + dtype=self._dtype) + self._helper.append_op( + type='elementwise_add', + inputs={'X': [pre_bias], + 'Y': [self._b]}, + outputs={'Out': [pre_activation]}, + attrs={'axis': self._num_flatten_dims}) + else: + pre_activation = pre_bias + # Currently, we don't support inplace in dygraph mode + return self._helper.append_activation(pre_activation, act=self._act) + + +class HingeLoss(object): + """ + Hing Loss Calculate class + """ + + def __init__(self, conf_dict): + """ + initialize + """ + self.margin = conf_dict["loss"]["margin"] + + def compute(self, pos, neg): + """ + compute loss + """ + elementwise_max = ElementwiseMaxLayer() + elementwise_add = ElementwiseAddLayer() + elementwise_sub = ElementwiseSubLayer() + constant = ConstantLayer() + reduce_mean = ReduceMeanLayer() + loss = reduce_mean.ops( + elementwise_max.ops( + constant.ops(neg, neg.shape, "float32", 0.0), + elementwise_add.ops( + elementwise_sub.ops(neg, pos), + 
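+            # i.e. loss = mean(max(0, neg - pos + margin)), the standard
+            # pairwise hinge loss composed from the wrapper layers above.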
constant.ops(neg, neg.shape, "float32", self.margin)))) + return loss + + +class BOW(Layer): + """ + BOW + """ + + def __init__(self, conf_dict): + """ + initialize + """ + super(BOW, self).__init__() + self.dict_size = conf_dict["dict_size"] + self.task_mode = conf_dict["task_mode"] + self.emb_dim = conf_dict["net"]["emb_dim"] + self.bow_dim = conf_dict["net"]["bow_dim"] + self.seq_len = conf_dict["seq_len"] + self.emb_layer = EmbeddingLayer(self.dict_size, self.emb_dim, + "emb").ops() + self.bow_layer = Linear(self.bow_dim, self.bow_dim) + self.bow_layer_po = FCLayer(self.bow_dim, None, "fc").ops() + self.softmax_layer = FCLayer(2, "softmax", "cos_sim").ops() + + @declarative + def forward(self, left, right): + """ + Forward network + """ + + # embedding layer + left_emb = self.emb_layer(left) + right_emb = self.emb_layer(right) + left_emb = fluid.layers.reshape( + left_emb, shape=[-1, self.seq_len, self.bow_dim]) + right_emb = fluid.layers.reshape( + right_emb, shape=[-1, self.seq_len, self.bow_dim]) + + bow_left = fluid.layers.reduce_sum(left_emb, dim=1) + bow_right = fluid.layers.reduce_sum(right_emb, dim=1) + softsign_layer = SoftsignLayer() + left_soft = softsign_layer.ops(bow_left) + right_soft = softsign_layer.ops(bow_right) + + left_bow = self.bow_layer(left_soft) + right_bow = self.bow_layer(right_soft) + cos_sim_layer = CosSimLayer() + pred = cos_sim_layer.ops(left_bow, right_bow) + return left_bow, pred + + # TODO(huihuangzheng): uncomment the following return statements after + # we fix it. + # + # matching layer + #if self.task_mode == "pairwise": + # left_bow = self.bow_layer(left_soft) + # right_bow = self.bow_layer(right_soft) + # cos_sim_layer = CosSimLayer() + # pred = cos_sim_layer.ops(left_bow, right_bow) + # return left_bow, pred + #else: + # concat_layer = ConcatLayer(1) + # concat = concat_layer.ops([left_soft, right_soft]) + # concat_fc = self.bow_layer_po(concat) + # pred = self.softmax_layer(concat_fc) + # return left_soft, pred diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_simnet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_simnet.py new file mode 100644 index 0000000000000000000000000000000000000000..94b9bb86be241cb3a08e3dd8e37942bd1202ab91 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_simnet.py @@ -0,0 +1,174 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
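+# This test builds the BOW pairwise-ranking model from simnet_dygraph_model.py,
+# trains it on randomly generated (query, pos_title, neg_title) triples, and
+# checks that the per-step hinge losses of the dygraph run and the
+# @declarative (static) run match.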
+ +import argparse +import numpy as np +import paddle +import paddle.fluid as fluid +import random +import unittest + +from paddle.fluid.dygraph import ProgramTranslator +from simnet_dygraph_model import BOW, HingeLoss + +SEED = 102 +random.seed(SEED) + + +def create_conf_dict(): + conf_dict = {} + conf_dict["task_mode"] = "train" + conf_dict["net"] = {"emb_dim": 128, "bow_dim": 128, "hidden_dim": 128} + conf_dict["loss"] = {"margin": 0.1} + return conf_dict + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Total examples' number in batch for training.") + parser.add_argument( + "--seq_len", type=int, default=32, help="The length of each sentence.") + parser.add_argument( + "--epoch", type=int, default=1, help="The number of training epoch.") + parser.add_argument( + "--fake_sample_size", + type=int, + default=128, + help="The number of samples of fake data.") + args = parser.parse_args([]) + return args + + +args = parse_args() + + +def fake_vocabulary(): + vocab = {} + vocab[""] = 0 + for i in range(26): + c = chr(ord('a') + i) + vocab[c] = i + 1 + return vocab + + +vocab = fake_vocabulary() + + +class FakeReaderProcessor(object): + def __init__(self, args, vocab): + self.vocab = vocab + self.seq_len = args.seq_len + self.sample_size = args.fake_sample_size + self.data_samples = [] + for i in range(self.sample_size): + query = [random.randint(0, 26) for i in range(self.seq_len)] + pos_title = query[:] + neg_title = [26 - q for q in query] + self.data_samples.append( + np.array([query, pos_title, neg_title]).astype(np.int64)) + + def get_reader(self, mode, epoch=0): + def reader_with_pairwise(): + if mode == "train": + for i in range(self.sample_size): + yield self.data_samples[i] + + return reader_with_pairwise + + +simnet_process = FakeReaderProcessor(args, vocab) + + +def train(conf_dict, to_static): + """ + train process + """ + program_translator = ProgramTranslator() + program_translator.enable(to_static) + + # Get device + if fluid.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + + with fluid.dygraph.guard(place): + fluid.default_startup_program().random_seed = SEED + fluid.default_main_program().random_seed = SEED + + conf_dict['dict_size'] = len(vocab) + conf_dict['seq_len'] = args.seq_len + + net = BOW(conf_dict) + loss = HingeLoss(conf_dict) + optimizer = fluid.optimizer.AdamOptimizer( + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + parameter_list=net.parameters()) + + metric = fluid.metrics.Auc(name="auc") + + global_step = 0 + losses = [] + + train_loader = fluid.io.DataLoader.from_generator( + capacity=16, + return_list=True, + iterable=True, + use_double_buffer=True) + get_train_examples = simnet_process.get_reader( + "train", epoch=args.epoch) + train_loader.set_sample_list_generator( + paddle.batch( + get_train_examples, batch_size=args.batch_size), place) + + for left, pos_right, neg_right in train_loader(): + left = fluid.layers.reshape(left, shape=[-1, 1]) + pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1]) + neg_right = fluid.layers.reshape(neg_right, shape=[-1, 1]) + net.train() + global_step += 1 + left_feat, pos_score = net(left, pos_right) + pred = pos_score + _, neg_score = net(left, neg_right) + avg_cost = loss.compute(pos_score, neg_score) + #avg_cost = loss.compute(pos_score, pos_score) + losses.append(np.mean(avg_cost.numpy())) + avg_cost.backward() + optimizer.minimize(avg_cost) + net.clear_gradients() 
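+        # losses holds one mean hinge-loss value per optimizer step; TestSimnet
+        # compares these values element-wise between the dygraph and static runs.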
+ return losses + + +class TestSimnet(unittest.TestCase): + def test_dygraph_static_same_loss(self): + if fluid.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) + conf_dict = create_conf_dict() + dygraph_loss = train(conf_dict, to_static=False) + static_loss = train(conf_dict, to_static=True) + + self.assertEqual(len(dygraph_loss), len(static_loss)) + for i in range(len(dygraph_loss)): + self.assertAlmostEqual(dygraph_loss[i], static_loss[i]) + + +if __name__ == '__main__': + unittest.main()