From f9ac5fb9925a950ae384eb65e972e8a548624d6c Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Mon, 13 Jul 2020 09:55:24 +0800 Subject: [PATCH] [Dy2stat] Fix Memory Optimization in run_program_op and Add SimNet as Unit Test (#25383) Add Similarity Net as unit test. During the unit test, we found three problems: 1. The run_program_op has memory optimization error when running dy2stat net multiple times. 2. The support for SelectedRows can cause problem in dy2stat. 3. The return grammar has problem. This PR fixes the 1. problem but modify codes for the 2. 3. problems to make PR smaller. I will fix those two problems in the next PR(s) --- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/run_program_op.h | 49 +- .../dygraph_to_static/simnet_dygraph_model.py | 516 ++++++++++++++++++ .../dygraph_to_static/test_simnet.py | 174 ++++++ 4 files changed, 733 insertions(+), 8 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/simnet_dygraph_model.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_simnet.py diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 40ed5869c2..23509773fa 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -91,7 +91,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_ten set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost) if (WITH_GPU) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor) endif() diff --git a/paddle/fluid/operators/run_program_op.h b/paddle/fluid/operators/run_program_op.h index 505ce4c096..ae09e87473 100644 --- a/paddle/fluid/operators/run_program_op.h +++ b/paddle/fluid/operators/run_program_op.h @@ -17,10 +17,12 @@ limitations under the License. */ #include #include #include +#include #include #include #include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" @@ -149,14 +151,46 @@ static void ShareVarsFromScope(const std::vector &vars, } } -static void AppendSkipDeletionVars( - std::vector *all_vars, - const std::vector &append_vars) { +static void AppendSkipDeletionVars(const std::vector &append_vars, + std::vector *all_vars) { for (auto &var : append_vars) { all_vars->emplace_back(var); } } +static void AppendSafeEagerDeletionSkipVars( + const framework::ProgramDesc &program, + std::vector *skip_vars) { + const framework::BlockDesc &block = program.Block(0); + const std::vector &all_ops = block.AllOps(); + + std::unordered_set grad_op_output; + std::unordered_set grad_op_input; + for (const framework::OpDesc *op : all_ops) { + int op_role = BOOST_GET_CONST( + int, op->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName())); + if ((op_role & static_cast(framework::OpRole::kBackward)) == 0) { + continue; + } + + for (const std::string &in_arg_name : op->InputArgumentNames()) { + grad_op_input.emplace(in_arg_name); + } + for (const std::string &out_arg_name : op->OutputArgumentNames()) { + grad_op_output.emplace(out_arg_name); + } + } + + // For the grad op input variables, if it is not output of grad_op, it may + // be output of forward op and we should set the variables as skip_var to + // prevent it being deleted when grad op is called multiple times. + for (const std::string &var_name : grad_op_input) { + if (grad_op_output.find(var_name) == grad_op_output.end()) { + skip_vars->emplace_back(var_name); + } + } +} + } // namespace details template @@ -192,7 +226,7 @@ class RunProgramOpKernel : public framework::OpKernel { // skip delete vars std::vector skip_vars; - details::AppendSkipDeletionVars(&skip_vars, output_var_names); + details::AppendSkipDeletionVars(output_var_names, &skip_vars); VLOG(2) << "Prepare to skip " << skip_vars.size() << " var(s): " << string::join_strings(skip_vars, ' '); @@ -261,20 +295,21 @@ class RunProgramGradOpKernel : public framework::OpKernel { out_scope_vec->size(), 1, platform::errors::InvalidArgument( "The OutScope of RunProgramGradOp should only hold one scope.")); + auto &scope = *(out_scope_vec->front()); // Step 2. prepare executor and scope framework::Executor exe(ctx.GetPlace()); // skip delete vars std::vector skip_vars; - details::AppendSkipDeletionVars(&skip_vars, input_grad_var_names); - details::AppendSkipDeletionVars(&skip_vars, param_grad_names); + details::AppendSkipDeletionVars(input_grad_var_names, &skip_vars); + details::AppendSkipDeletionVars(param_grad_names, &skip_vars); + details::AppendSafeEagerDeletionSkipVars(*program, &skip_vars); VLOG(2) << "Prepare to skip " << skip_vars.size() << " var(s): " << string::join_strings(skip_vars, ' '); auto exe_ctx = exe.Prepare(*program, 0, skip_vars); - auto &scope = *(out_scope_vec->front()); details::ShareVarsIntoScope(output_grad_vars, output_grad_var_names, &scope); diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/simnet_dygraph_model.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/simnet_dygraph_model.py new file mode 100644 index 0000000000..2520b37228 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/simnet_dygraph_model.py @@ -0,0 +1,516 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import paddle.fluid.param_attr as attr + +from functools import reduce +from paddle.fluid.dygraph import declarative, to_variable +from paddle.fluid.dygraph import Embedding, Layer, Linear + + +class EmbeddingLayer(object): + """ + Embedding Layer class + """ + + def __init__(self, dict_size, emb_dim, name="emb", padding_idx=None): + """ + initialize + """ + self.dict_size = dict_size + self.emb_dim = emb_dim + self.name = name + self.padding_idx = padding_idx + + def ops(self): + """ + operation + """ + # TODO(huihuangzheng): The original code set the is_sparse=True, but it + # causes crush in dy2stat. Set it to True after fixing it. + emb = Embedding( + size=[self.dict_size, self.emb_dim], + is_sparse=False, + padding_idx=self.padding_idx, + param_attr=attr.ParamAttr( + name=self.name, initializer=fluid.initializer.Xavier())) + + return emb + + +class FCLayer(object): + """ + Fully Connect Layer class + """ + + def __init__(self, fc_dim, act, name="fc"): + """ + initialize + """ + self.fc_dim = fc_dim + self.act = act + self.name = name + + def ops(self): + """ + operation + """ + fc = FC(size=self.fc_dim, + param_attr=attr.ParamAttr(name="%s.w" % self.name), + bias_attr=attr.ParamAttr(name="%s.b" % self.name), + act=self.act) + return fc + + +class ConcatLayer(object): + """ + Connection Layer class + """ + + def __init__(self, axis): + """ + initialize + """ + self.axis = axis + + def ops(self, inputs): + """ + operation + """ + concat = fluid.layers.concat(inputs, axis=self.axis) + return concat + + +class ReduceMeanLayer(object): + """ + Reduce Mean Layer class + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, input): + """ + operation + """ + mean = fluid.layers.reduce_mean(input) + return mean + + +class CosSimLayer(object): + """ + Cos Similarly Calculate Layer + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, x, y): + """ + operation + """ + sim = fluid.layers.cos_sim(x, y) + return sim + + +class ElementwiseMaxLayer(object): + """ + Elementwise Max Layer class + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, x, y): + """ + operation + """ + max = fluid.layers.elementwise_max(x, y) + return max + + +class ElementwiseAddLayer(object): + """ + Elementwise Add Layer class + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, x, y): + """ + operation + """ + add = fluid.layers.elementwise_add(x, y) + return add + + +class ElementwiseSubLayer(object): + """ + Elementwise Add Layer class + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, x, y): + """ + operation + """ + sub = fluid.layers.elementwise_sub(x, y) + return sub + + +class ConstantLayer(object): + """ + Generate A Constant Layer class + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, input, shape, dtype, value): + """ + operation + """ + shape = list(shape) + input_shape = fluid.layers.shape(input) + shape[0] = input_shape[0] + constant = fluid.layers.fill_constant(shape, dtype, value) + return constant + + +class SoftsignLayer(object): + """ + Softsign Layer class + """ + + def __init__(self): + """ + initialize + """ + pass + + def ops(self, input): + """ + operation + """ + softsign = fluid.layers.softsign(input) + return softsign + + +class FC(Layer): + """ + This interface is used to construct a callable object of the ``FC`` class. + For more details, refer to code examples. + It creates a fully connected layer in the network. It can take + one or multiple ``Tensor`` as its inputs. It creates a Variable called weights for each input tensor, + which represents a fully connected weight matrix from each input unit to + each output unit. The fully connected layer multiplies each input tensor + with its corresponding weight to produce an output Tensor with shape [N, `size`], + where N is batch size. If multiple input tensors are given, the results of + multiple output tensors with shape [N, `size`] will be summed up. If ``bias_attr`` + is not None, a bias variable will be created and added to the output. + Finally, if ``act`` is not None, it will be applied to the output as well. + When the input is single ``Tensor`` : + .. math:: + Out = Act({XW + b}) + When the input are multiple ``Tensor`` : + .. math:: + Out = Act({\sum_{i=0}^{N-1}X_iW_i + b}) + In the above equation: + * :math:`N`: Number of the input. N equals to len(input) if input is list of ``Tensor`` . + * :math:`X_i`: The i-th input ``Tensor`` . + * :math:`W_i`: The i-th weights matrix corresponding i-th input tensor. + * :math:`b`: The bias parameter created by this layer (if needed). + * :math:`Act`: The activation function. + * :math:`Out`: The output ``Tensor`` . + See below for an example. + .. code-block:: text + Given: + data_1.data = [[[0.1, 0.2]]] + data_1.shape = (1, 1, 2) # 1 is batch_size + data_2.data = [[[0.1, 0.2, 0.3]]] + data_2.shape = (1, 1, 3) # 1 is batch_size + fc = FC("fc", 2, num_flatten_dims=2) + out = fc(input=[data_1, data_2]) + Then: + out.data = [[[0.182996 -0.474117]]] + out.shape = (1, 1, 2) + Parameters: + + size(int): The number of output units in this layer. + num_flatten_dims (int, optional): The fc layer can accept an input tensor with more than + two dimensions. If this happens, the multi-dimension tensor will first be flattened + into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input + tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1) + dimensions will be flatten to form the first dimension of the final matrix (height of + the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to + form the second dimension of the final matrix (width of the matrix). For example, suppose + `X` is a 5-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3. + Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. Default: 1 + param_attr (ParamAttr or list of ParamAttr, optional): The parameter attribute for learnable + weights(Parameter) of this layer. Default: None. + bias_attr (ParamAttr or list of ParamAttr, optional): The attribute for the bias + of this layer. If it is set to False, no bias will be added to the output units. + If it is set to None, the bias is initialized zero. Default: None. + act (str, optional): Activation to be applied to the output of this layer. Default: None. + is_test(bool, optional): A flag indicating whether execution is in test phase. Default: False. + dtype(str, optional): Dtype used for weight, it can be "float32" or "float64". Default: "float32". + Attribute: + **weight** (list of Parameter): the learnable weights of this layer. + **bias** (Parameter or None): the learnable bias of this layer. + Returns: + None + + Examples: + .. code-block:: python + from paddle.fluid.dygraph.base import to_variable + import paddle.fluid as fluid + from paddle.fluid.dygraph import FC + import numpy as np + data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32') + with fluid.dygraph.guard(): + fc = FC("fc", 64, num_flatten_dims=2) + data = to_variable(data) + conv = fc(data) + """ + + def __init__(self, + size, + num_flatten_dims=1, + param_attr=None, + bias_attr=None, + act=None, + is_test=False, + dtype="float32"): + super(FC, self).__init__(dtype) + + self._size = size + self._num_flatten_dims = num_flatten_dims + self._dtype = dtype + self._param_attr = param_attr + self._bias_attr = bias_attr + self._act = act + self.__w = list() + + def _build_once(self, input): + i = 0 + for inp, param in self._helper.iter_inputs_and_params(input, + self._param_attr): + input_shape = inp.shape + + param_shape = [ + reduce(lambda a, b: a * b, input_shape[self._num_flatten_dims:], + 1) + ] + [self._size] + self.__w.append( + self.add_parameter( + '_w%d' % i, + self.create_parameter( + attr=param, + shape=param_shape, + dtype=self._dtype, + is_bias=False))) + i += 1 + + size = list([self._size]) + self._b = self.create_parameter( + attr=self._bias_attr, shape=size, dtype=self._dtype, is_bias=True) + + # TODO(songyouwei): We should remove _w property + @property + def _w(self, i=0): + return self.__w[i] + + @_w.setter + def _w(self, value, i=0): + assert isinstance(self.__w[i], Variable) + self.__w[i].set_value(value) + + @property + def weight(self): + if len(self.__w) > 1: + return self.__w + else: + return self.__w[0] + + @weight.setter + def weight(self, value): + if len(self.__w) == 1: + self.__w[0] = value + + @property + def bias(self): + return self._b + + @bias.setter + def bias(self, value): + self._b = value + + def forward(self, input): + mul_results = list() + i = 0 + for inp, param in self._helper.iter_inputs_and_params(input, + self._param_attr): + tmp = self._helper.create_variable_for_type_inference(self._dtype) + self._helper.append_op( + type="mul", + inputs={"X": inp, + "Y": self.__w[i]}, + outputs={"Out": tmp}, + attrs={ + "x_num_col_dims": self._num_flatten_dims, + "y_num_col_dims": 1 + }) + i += 1 + mul_results.append(tmp) + + if len(mul_results) == 1: + pre_bias = mul_results[0] + else: + pre_bias = self._helper.create_variable_for_type_inference( + self._dtype) + self._helper.append_op( + type="sum", + inputs={"X": mul_results}, + outputs={"Out": pre_bias}, + attrs={"use_mkldnn": False}) + + if self._b is not None: + pre_activation = self._helper.create_variable_for_type_inference( + dtype=self._dtype) + self._helper.append_op( + type='elementwise_add', + inputs={'X': [pre_bias], + 'Y': [self._b]}, + outputs={'Out': [pre_activation]}, + attrs={'axis': self._num_flatten_dims}) + else: + pre_activation = pre_bias + # Currently, we don't support inplace in dygraph mode + return self._helper.append_activation(pre_activation, act=self._act) + + +class HingeLoss(object): + """ + Hing Loss Calculate class + """ + + def __init__(self, conf_dict): + """ + initialize + """ + self.margin = conf_dict["loss"]["margin"] + + def compute(self, pos, neg): + """ + compute loss + """ + elementwise_max = ElementwiseMaxLayer() + elementwise_add = ElementwiseAddLayer() + elementwise_sub = ElementwiseSubLayer() + constant = ConstantLayer() + reduce_mean = ReduceMeanLayer() + loss = reduce_mean.ops( + elementwise_max.ops( + constant.ops(neg, neg.shape, "float32", 0.0), + elementwise_add.ops( + elementwise_sub.ops(neg, pos), + constant.ops(neg, neg.shape, "float32", self.margin)))) + return loss + + +class BOW(Layer): + """ + BOW + """ + + def __init__(self, conf_dict): + """ + initialize + """ + super(BOW, self).__init__() + self.dict_size = conf_dict["dict_size"] + self.task_mode = conf_dict["task_mode"] + self.emb_dim = conf_dict["net"]["emb_dim"] + self.bow_dim = conf_dict["net"]["bow_dim"] + self.seq_len = conf_dict["seq_len"] + self.emb_layer = EmbeddingLayer(self.dict_size, self.emb_dim, + "emb").ops() + self.bow_layer = Linear(self.bow_dim, self.bow_dim) + self.bow_layer_po = FCLayer(self.bow_dim, None, "fc").ops() + self.softmax_layer = FCLayer(2, "softmax", "cos_sim").ops() + + @declarative + def forward(self, left, right): + """ + Forward network + """ + + # embedding layer + left_emb = self.emb_layer(left) + right_emb = self.emb_layer(right) + left_emb = fluid.layers.reshape( + left_emb, shape=[-1, self.seq_len, self.bow_dim]) + right_emb = fluid.layers.reshape( + right_emb, shape=[-1, self.seq_len, self.bow_dim]) + + bow_left = fluid.layers.reduce_sum(left_emb, dim=1) + bow_right = fluid.layers.reduce_sum(right_emb, dim=1) + softsign_layer = SoftsignLayer() + left_soft = softsign_layer.ops(bow_left) + right_soft = softsign_layer.ops(bow_right) + + left_bow = self.bow_layer(left_soft) + right_bow = self.bow_layer(right_soft) + cos_sim_layer = CosSimLayer() + pred = cos_sim_layer.ops(left_bow, right_bow) + return left_bow, pred + + # TODO(huihuangzheng): uncomment the following return statements after + # we fix it. + # + # matching layer + #if self.task_mode == "pairwise": + # left_bow = self.bow_layer(left_soft) + # right_bow = self.bow_layer(right_soft) + # cos_sim_layer = CosSimLayer() + # pred = cos_sim_layer.ops(left_bow, right_bow) + # return left_bow, pred + #else: + # concat_layer = ConcatLayer(1) + # concat = concat_layer.ops([left_soft, right_soft]) + # concat_fc = self.bow_layer_po(concat) + # pred = self.softmax_layer(concat_fc) + # return left_soft, pred diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_simnet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_simnet.py new file mode 100644 index 0000000000..94b9bb86be --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_simnet.py @@ -0,0 +1,174 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import numpy as np +import paddle +import paddle.fluid as fluid +import random +import unittest + +from paddle.fluid.dygraph import ProgramTranslator +from simnet_dygraph_model import BOW, HingeLoss + +SEED = 102 +random.seed(SEED) + + +def create_conf_dict(): + conf_dict = {} + conf_dict["task_mode"] = "train" + conf_dict["net"] = {"emb_dim": 128, "bow_dim": 128, "hidden_dim": 128} + conf_dict["loss"] = {"margin": 0.1} + return conf_dict + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Total examples' number in batch for training.") + parser.add_argument( + "--seq_len", type=int, default=32, help="The length of each sentence.") + parser.add_argument( + "--epoch", type=int, default=1, help="The number of training epoch.") + parser.add_argument( + "--fake_sample_size", + type=int, + default=128, + help="The number of samples of fake data.") + args = parser.parse_args([]) + return args + + +args = parse_args() + + +def fake_vocabulary(): + vocab = {} + vocab[""] = 0 + for i in range(26): + c = chr(ord('a') + i) + vocab[c] = i + 1 + return vocab + + +vocab = fake_vocabulary() + + +class FakeReaderProcessor(object): + def __init__(self, args, vocab): + self.vocab = vocab + self.seq_len = args.seq_len + self.sample_size = args.fake_sample_size + self.data_samples = [] + for i in range(self.sample_size): + query = [random.randint(0, 26) for i in range(self.seq_len)] + pos_title = query[:] + neg_title = [26 - q for q in query] + self.data_samples.append( + np.array([query, pos_title, neg_title]).astype(np.int64)) + + def get_reader(self, mode, epoch=0): + def reader_with_pairwise(): + if mode == "train": + for i in range(self.sample_size): + yield self.data_samples[i] + + return reader_with_pairwise + + +simnet_process = FakeReaderProcessor(args, vocab) + + +def train(conf_dict, to_static): + """ + train process + """ + program_translator = ProgramTranslator() + program_translator.enable(to_static) + + # Get device + if fluid.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + + with fluid.dygraph.guard(place): + fluid.default_startup_program().random_seed = SEED + fluid.default_main_program().random_seed = SEED + + conf_dict['dict_size'] = len(vocab) + conf_dict['seq_len'] = args.seq_len + + net = BOW(conf_dict) + loss = HingeLoss(conf_dict) + optimizer = fluid.optimizer.AdamOptimizer( + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + parameter_list=net.parameters()) + + metric = fluid.metrics.Auc(name="auc") + + global_step = 0 + losses = [] + + train_loader = fluid.io.DataLoader.from_generator( + capacity=16, + return_list=True, + iterable=True, + use_double_buffer=True) + get_train_examples = simnet_process.get_reader( + "train", epoch=args.epoch) + train_loader.set_sample_list_generator( + paddle.batch( + get_train_examples, batch_size=args.batch_size), place) + + for left, pos_right, neg_right in train_loader(): + left = fluid.layers.reshape(left, shape=[-1, 1]) + pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1]) + neg_right = fluid.layers.reshape(neg_right, shape=[-1, 1]) + net.train() + global_step += 1 + left_feat, pos_score = net(left, pos_right) + pred = pos_score + _, neg_score = net(left, neg_right) + avg_cost = loss.compute(pos_score, neg_score) + #avg_cost = loss.compute(pos_score, pos_score) + losses.append(np.mean(avg_cost.numpy())) + avg_cost.backward() + optimizer.minimize(avg_cost) + net.clear_gradients() + return losses + + +class TestSimnet(unittest.TestCase): + def test_dygraph_static_same_loss(self): + if fluid.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) + conf_dict = create_conf_dict() + dygraph_loss = train(conf_dict, to_static=False) + static_loss = train(conf_dict, to_static=True) + + self.assertEqual(len(dygraph_loss), len(static_loss)) + for i in range(len(dygraph_loss)): + self.assertAlmostEqual(dygraph_loss[i], static_loss[i]) + + +if __name__ == '__main__': + unittest.main() -- GitLab