未验证 提交 7d840e2e 编写于 作者: H huangxu96 提交者: GitHub

for supporting paddle rc1 (#488)

* for supporting paddle rc1

* add resnet_amp.py file for resent50 amp training.

* change yaml file name
上级 8d16077e
mode: 'train'
ARCHITECTURE:
name: 'ResNet50'
name: 'ResNet50_amp'
pretrained_model: ""
model_save_dir: "./output/"
......@@ -11,17 +11,16 @@ validate: True
valid_interval: 1
epochs: 120
topk: 5
image_shape: [3, 224, 224]
is_distributed: True
is_distributed: False
use_gpu: True
# mixed precision training
# mixed precision training related config
use_amp: True
use_pure_fp16: False
multi_precision: False
scale_loss: 128.0
use_dynamic_loss_scaling: True
data_format: "NCHW"
image_shape: [3, 224, 224]
data_format: "NHWC"
image_shape: [4, 224, 224]
use_dali : True
use_mix: False
ls_epsilon: -1
......@@ -42,7 +41,7 @@ OPTIMIZER:
factor: 0.000100
TRAIN:
batch_size: 256
batch_size: 128
num_workers: 4
file_list: "./dataset/ILSVRC2012/train_list.txt"
data_dir: "./dataset/ILSVRC2012/"
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .resnet_amp import ResNet50_amp
from .resnet_vc import ResNet18_vc, ResNet34_vc, ResNet50_vc, ResNet101_vc, ResNet152_vc
from .resnet_vd import ResNet18_vd, ResNet34_vd, ResNet50_vd, ResNet101_vd, ResNet152_vd, ResNet200_vd
from .resnext import ResNeXt50_32x4d, ResNeXt50_64x4d, ResNeXt101_32x4d, ResNeXt101_64x4d, ResNeXt152_32x4d, ResNeXt152_64x4d
......
from paddle.fluid.dygraph import layers
from paddle.fluid import core
from paddle.fluid.initializer import Constant
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.data_feeder import check_variable_and_dtype
import os
class BatchNorm(layers.Layer):
r"""
:alias_main: paddle.nn.BatchNorm
:alias: paddle.nn.BatchNorm,paddle.nn.layer.BatchNorm,paddle.nn.layer.norm.BatchNorm
:old_api: paddle.fluid.dygraph.BatchNorm
This interface is used to construct a callable object of the ``BatchNorm`` class.
For more details, refer to code examples.
It implements the function of the Batch Normalization Layer and can be used
as a normalizer function for conv2d and fully connected operations.
The data is normalized by the mean and variance of the channel based on the current batch data.
Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/pdf/1502.03167.pdf>`_
for more details.
When use_global_stats = False, the :math:`\\mu_{\\beta}`
and :math:`\\sigma_{\\beta}^{2}` are the statistics of one mini-batch.
Calculated as follows:
.. math::
\\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\
\ mini-batch\ mean \\\\
\\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\
\\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\
- :math:`x` : mini-batch data
- :math:`m` : the size of the mini-batch data
When use_global_stats = True, the :math:`\\mu_{\\beta}`
and :math:`\\sigma_{\\beta}^{2}` are not the statistics of one mini-batch.
They are global or running statistics (moving_mean and moving_variance). It usually got from the
pre-trained model. Calculated as follows:
.. math::
moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global mean \\
moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global variance \\
The normalization function formula is as follows:
.. math::
\\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\
\\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\
y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift
- :math:`\\epsilon` : add a smaller value to the variance to prevent division by zero
- :math:`\\gamma` : trainable proportional parameter
- :math:`\\beta` : trainable deviation parameter
Parameters:
num_channels(int): Indicate the number of channels of the input ``Tensor``.
act(str, optional): Activation to be applied to the output of batch normalization. Default: None.
is_test (bool, optional): A flag indicating whether it is in test phrase or not.
This flag only has effect on static graph mode. For dygraph mode, please use ``eval()``.
Default: False.
momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
param_attr(ParamAttr, optional): The parameter attribute for Parameter `scale`
of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr(ParamAttr, optional): The parameter attribute for the bias of batch_norm.
If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
dtype(str, optional): Indicate the data type of the input ``Tensor``,
which can be float32 or float64. Default: float32.
data_layout(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW.
in_place(bool, optional): Make the input and output of batch norm reuse memory. Default: False.
moving_mean_name(str, optional): The name of moving_mean which store the global Mean. Default: None.
moving_variance_name(str, optional): The name of the moving_variance which store the global Variance. Default: None.
do_model_average_for_mean_and_var(bool, optional): Whether parameter mean and variance should do model
average when model average is enabled. Default: True.
use_global_stats(bool, optional): Whether to use global mean and
variance. In inference or test mode, set use_global_stats to true
or is_test to true, and the behavior is equivalent.
In train mode, when setting use_global_stats True, the global mean
and variance are also used during train period. Default: False.
trainable_statistics(bool, optional): Whether to calculate mean and var in eval mode. In eval mode, when
setting trainable_statistics True, mean and variance will be calculated by current batch statistics.
Default: False.
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
import numpy as np
x = np.random.random(size=(3, 10, 3, 7)).astype('float32')
with fluid.dygraph.guard():
x = to_variable(x)
batch_norm = fluid.BatchNorm(10)
hidden1 = batch_norm(x)
"""
def __init__(self,
num_channels,
act=None,
is_test=False,
momentum=0.9,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
dtype='float32',
data_layout='NCHW',
in_place=False,
moving_mean_name=None,
moving_variance_name=None,
do_model_average_for_mean_and_var=True,
use_global_stats=False,
trainable_statistics=False):
super(BatchNorm, self).__init__()
self._param_attr = param_attr
self._bias_attr = bias_attr
self._act = act
self._use_mkldnn = core.globals()["FLAGS_use_mkldnn"]
assert bias_attr is not False, "bias_attr should not be False in batch_norm."
if dtype == "float16":
self._dtype = "float32"
else:
self._dtype = dtype
param_shape = [num_channels]
# create parameter
self.weight = self.create_parameter(
attr=self._param_attr,
shape=param_shape,
dtype=self._dtype,
default_initializer=Constant(1.0))
self.weight.stop_gradient = use_global_stats and self._param_attr.learning_rate == 0.
self.bias = self.create_parameter(
attr=self._bias_attr,
shape=param_shape,
dtype=self._dtype,
is_bias=True)
self.bias.stop_gradient = use_global_stats and self._param_attr.learning_rate == 0.
self._mean = self.create_parameter(
attr=ParamAttr(
name=moving_mean_name,
initializer=Constant(0.0),
trainable=False,
do_model_average=do_model_average_for_mean_and_var),
shape=param_shape,
dtype=self._dtype)
self._mean.stop_gradient = True
self._variance = self.create_parameter(
attr=ParamAttr(
name=moving_variance_name,
initializer=Constant(1.0),
trainable=False,
do_model_average=do_model_average_for_mean_and_var),
shape=param_shape,
dtype=self._dtype)
self._variance.stop_gradient = True
self._has_reserve_space = False
if data_layout == 'NHWC':
flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent')
if flag is not None and flag.lower() in ['true', '1']:
self._has_reserve_space = True
self._in_place = in_place
self._data_layout = data_layout
self._momentum = momentum
self._epsilon = epsilon
self._is_test = is_test
self._fuse_with_relu = False
self._use_global_stats = use_global_stats
self._trainable_statistics = trainable_statistics
def forward(self, input):
# create output
# mean and mean_out share the same memory
mean_out = self._mean
# variance and variance out share the same memory
variance_out = self._variance
if in_dygraph_mode():
attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
"is_test", not self.training, "data_layout",
self._data_layout, "use_mkldnn", self._use_mkldnn,
"fuse_with_relu", self._fuse_with_relu, "use_global_stats",
self._use_global_stats, 'trainable_statistics',
self._trainable_statistics)
batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
input, self.weight, self.bias, self._mean, self._variance,
mean_out, variance_out, *attrs)
return dygraph_utils._append_activation_in_dygraph(
batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn)
check_variable_and_dtype(input, 'input',
['float16', 'float32', 'float64'], 'BatchNorm')
attrs = {
"momentum": self._momentum,
"epsilon": self._epsilon,
"is_test": self._is_test,
"data_layout": self._data_layout,
"use_mkldnn": False,
"fuse_with_relu": self._fuse_with_relu,
"use_global_stats": self._use_global_stats,
"trainable_statistics": self._trainable_statistics,
}
inputs = {
"X": [input],
"Scale": [self.weight],
"Bias": [self.bias],
"Mean": [self._mean],
"Variance": [self._variance]
}
saved_mean = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
saved_variance = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
reserve_space = None
if self._has_reserve_space:
reserve_space = self._helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.FP16, stop_gradient=True)
batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference(
self._dtype)
outputs = {
"Y": [batch_norm_out],
"MeanOut": [mean_out],
"VarianceOut": [variance_out],
"SavedMean": [saved_mean],
"SavedVariance": [saved_variance]
}
if reserve_space is not None:
outputs["ReserveSpace"] = reserve_space
self._helper.append_op(
type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
# Currently, we don't support inplace in dygraph mode
return self._helper.append_activation(batch_norm_out, self._act)
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
from .nn import BatchNorm
import math
__all__ = ["ResNet50_amp"]
class ConvBNLayer(nn.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None,
data_format="NCHW"):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
data_format=data_format)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(bn_name + "_offset"),
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance",
data_layout=data_format)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None,
data_format="NCHW"):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act="relu",
name=name + "_branch2a",
data_format=data_format)
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act="relu",
name=name + "_branch2b",
data_format=data_format)
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c",
data_format=data_format)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride,
name=name + "_branch1",
data_format=data_format)
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None,
data_format="NCHW"):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
stride=stride,
act="relu",
name=name + "_branch2a",
data_format=data_format)
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b",
data_format=data_format)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
stride=stride,
name=name + "_branch1",
data_format=data_format)
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet(nn.Layer):
def __init__(self, layers=50, class_dim=1000, input_image_channel=3, data_format="NCHW"):
super(ResNet, self).__init__()
self.layers = layers
self.data_format = data_format
self.input_image_channel = input_image_channel
supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.conv = ConvBNLayer(
num_channels=self.input_image_channel,
num_filters=64,
filter_size=7,
stride=2,
act="relu",
name="conv1",
data_format=self.data_format)
self.pool2d_max = MaxPool2D(
kernel_size=3,
stride=2,
padding=1,
data_format=self.data_format)
self.block_list = []
if layers >= 50:
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
conv_name,
BottleneckBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
name=conv_name,
data_format=self.data_format))
self.block_list.append(bottleneck_block)
shortcut = True
else:
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
conv_name,
BasicBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block],
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
name=conv_name,
data_format=self.data_format))
self.block_list.append(basic_block)
shortcut = True
self.pool2d_avg = AdaptiveAvgPool2D(1, data_format=self.data_format)
self.pool2d_avg_channels = num_channels[-1] * 2
stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
self.out = Linear(
self.pool2d_avg_channels,
class_dim,
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
def forward(self, inputs):
y = self.conv(inputs)
y = self.pool2d_max(y)
for block in self.block_list:
y = block(y)
y = self.pool2d_avg(y)
y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y)
return y
def ResNet50_amp(**args):
model = ResNet(layers=50, **args)
return model
......@@ -82,17 +82,12 @@ class Momentum(object):
self.momentum = momentum
self.parameter_list = parameter_list
self.regularization = regularization
self.multi_precision = config.get('multi_precision', False)
self.rescale_grad = (1.0 / (config['TRAIN']['batch_size'] / len(fluid.cuda_places()))
if config.get('use_pure_fp16', False) else 1.0)
def __call__(self):
opt = fluid.contrib.optimizer.Momentum(
learning_rate=self.learning_rate,
momentum=self.momentum,
regularization=self.regularization,
multi_precision=self.multi_precision,
rescale_grad=self.rescale_grad)
regularization=self.regularization)
return opt
......
......@@ -26,7 +26,7 @@ from optimizer import OptimizerBuilder
import paddle
import paddle.nn.functional as F
from paddle import fluid
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16
from ppcls.optimizer.learning_rate import LearningRateBuilder
from ppcls.modeling import architectures
......@@ -83,7 +83,6 @@ def create_model(architecture, image, classes_num, config, is_train):
Returns:
out(variable): model output variable
"""
use_pure_fp16 = config.get("use_pure_fp16", False)
name = architecture["name"]
params = architecture.get("params", {})
......@@ -101,15 +100,10 @@ def create_model(architecture, image, classes_num, config, is_train):
params['is_test'] = not is_train
model = architectures.__dict__[name](class_dim=classes_num, **params)
if use_pure_fp16 and not config.get("use_dali", False):
image = image.astype('float16')
if data_format == "NHWC":
image = paddle.tensor.transpose(image, [0, 2, 3, 1])
image.stop_gradient = True
out = model(image)
if config.get("use_pure_fp16", False):
cast_model_to_fp16(paddle.static.default_main_program())
out = out.astype('float32')
return out
......@@ -119,8 +113,7 @@ def create_loss(out,
classes_num=1000,
epsilon=None,
use_mix=False,
use_distillation=False,
use_pure_fp16=False):
use_distillation=False):
"""
Create a loss for optimization, such as:
1. CrossEnotry loss
......@@ -137,7 +130,6 @@ def create_loss(out,
classes_num(int): num of classes
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
use_pure_fp16(bool): whether to use pure fp16 data as training parameter
Returns:
loss(variable): loss variable
......@@ -162,10 +154,10 @@ def create_loss(out,
if use_mix:
loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
return loss(out, feed_y_a, feed_y_b, feed_lam, use_pure_fp16)
return loss(out, feed_y_a, feed_y_b, feed_lam)
else:
loss = CELoss(class_dim=classes_num, epsilon=epsilon)
return loss(out, target, use_pure_fp16)
return loss(out, target)
def create_metric(out,
......@@ -239,9 +231,8 @@ def create_fetchs(out,
fetchs(dict): dict of model outputs(included loss and measures)
"""
fetchs = OrderedDict()
use_pure_fp16 = config.get("use_pure_fp16", False)
loss = create_loss(out, feeds, architecture, classes_num, epsilon, use_mix,
use_distillation, use_pure_fp16)
use_distillation)
fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True))
if not use_mix:
metric = create_metric(out, feeds, architecture, topk, classes_num,
......@@ -403,11 +394,9 @@ def compile(config, program, loss_name=None, share_prog=None):
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = 10000 if config.get(
'use_pure_fp16', False) else 10
exec_strategy.num_iteration_per_drop_scope = 10
fuse_op = config.get('use_amp', False) or config.get('use_pure_fp16',
False)
fuse_op = config.get('use_amp', False)
fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
......
......@@ -27,7 +27,6 @@ from sys import version_info
import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
from paddle.distributed import fleet
from ppcls.data import Reader
......@@ -68,8 +67,7 @@ def main(args):
use_gpu = config.get("use_gpu", True)
# amp related config
use_amp = config.get('use_amp', False)
use_pure_fp16 = config.get('use_pure_fp16', False)
if use_amp or use_pure_fp16:
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_exhaustive_search': 1,
'FLAGS_conv_workspace_size_limit': 4000,
......@@ -119,8 +117,6 @@ def main(args):
exe = paddle.static.Executor(place)
# Parameter initialization
exe.run(startup_prog)
if config.get("use_pure_fp16", False):
cast_parameters_to_fp16(place, train_prog, fluid.global_scope())
# load pretrained models or checkpoints
init_model(config, train_prog, exe)
......@@ -141,11 +137,11 @@ def main(args):
else:
assert use_gpu is True, "DALI only support gpu, please set use_gpu to True!"
import dali
train_dataloader = dali.train(config)
os.environ["FLAGS_fraction_of_gpu_memory_to_use"] = "0.8"
train_dataloader = dali.train(config)
if config.validate and paddle.distributed.get_rank() == 0:
valid_dataloader = dali.val(config)
compiled_valid_prog = program.compile(config, valid_prog)
vdl_writer = None
if args.vdl_dir:
if version_info.major == 2:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册