Unverified · Commit dfed4a63 authored by Haohongxiang, committed by GitHub

support offload in sharding stage2 (#37904)

* merge latest develop branch

* fix bugs

* update

* fix bugs in unit tests

* modify to reduce GPU memory usage

* fix bugs when using _reset_grad_inplace_version

* update

* update

* modify for CI-Coverage

* retrigger all CIs
Parent a9f81534
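This change adds an `offload` option to `ShardingOptimizerStage2`: FP32 master weights are kept on the CPU and the optimizer step runs there, while forward/backward stay in FP16 on the GPU. Below is a minimal usage sketch, condensed from the `dygraph_sharding_stage2_offload.py` test added in this commit; it assumes two GPUs launched with `paddle.distributed.launch`, and borrows the `MLP` model and `optimizer_setting` helper from the existing `dygraph_sharding_stage2.py` test.

import paddle
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
# collective/fleet init happens at import time in the dygraph_sharding_stage2 test module
from dygraph_sharding_stage2 import MLP, optimizer_setting

group = paddle.distributed.new_group([0, 1])

model = MLP(8000)
optimizer = optimizer_setting(model=model, use_pure_fp16=True)
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')

scaler = ShardingScaler(paddle.amp.GradScaler(init_loss_scaling=32768), group)
optimizer = ShardingOptimizerStage2(
    params=model.parameters(), optim=optimizer, group=group, offload=True)
model = ShardingStage2(model, optimizer, group=group, accumulate_grads=True)

data = paddle.rand([4, 8000], dtype='float32')  # stand-in for one training batch
with paddle.amp.auto_cast(True, level='O2'):
    loss = paddle.mean(model(data).cast('float32'))
scaler.scale(loss).backward()
model.grad_scale()       # apply the deferred 1/world_size gradient scaling
scaler.step(optimizer)   # with offload=True, the inner step runs on the CPU master weights
scaler.update()
model.clear_gradients()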
......@@ -27,11 +27,13 @@ from collections import OrderedDict
import paddle
import paddle.fluid as fluid
from paddle import framework
from paddle.fluid import core
import paddle.distributed as dist
from paddle.optimizer import Optimizer
from paddle.fluid.clip import ClipGradByGlobalNorm
from ...utils.internal_storage import ParamStorage
from ...meta_parallel.sharding.sharding_utils import Type
from ...meta_parallel.sharding.sharding_utils import Type, device_guard, ShardingClipGrad
# CUDA alignment 256 bytes
alignment = {"gpu": 256, }
......@@ -99,16 +101,41 @@ class ShardingOptimizerStage2(Optimizer):
self.broadcast_fp16 = broadcast_fp16
self.param_storages = {} # {dtype: {rank: InternalStorage}}
if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm):
logging.warning(
"While using ClipGradByGlobalNorm in ShardingOptimizer, the grad clip of original optimizer will be changed."
)
self._optim._grad_clip = ShardingClipGrad(self._optim._grad_clip,
group,
paddle.get_device())
if offload:
assert self._pfp16, "Only support offload strategy while using \'Adam\', \'AdamW\' and \'Momentum\' optimizer with AMP/Pure FP16"
self.offload = offload # Using for offload
self.offload_device = "cpu"
self._master_params = {}
# Update optimizer parameters and adjust parameter storage and use according to rank.
self.update_opt_status()
def _generate_master_params(self, trainable_params):
for param in trainable_params:
if param.dtype == Type.fp16.value:
self._optim._master_weights[param.name] = paddle.cast(
param, Type.fp32.value)
if self.offload:
for param in trainable_params:
if param.name not in self._master_params.keys():
self._master_params[param.name] = core.VarBase(
name=param.name,
value=param.cast(dtype=Type.fp32.value).numpy(),
place=core.CPUPlace(),
stop_gradient=param.stop_gradient)
self._optim._master_weights = self._master_params
else:
for param in trainable_params:
if param.dtype == Type.fp16.value:
self._optim._master_weights[param.name] = paddle.cast(
param, Type.fp32.value)
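# Illustrative sketch (not part of this commit): with offload enabled, the branch
# above keeps an FP32 "master" copy of every trainable FP16 parameter on the CPU
# and points the wrapped optimizer at those copies. Roughly the same idea with
# public APIs (the toy `toy_param` below stands in for a real trainable fp16 parameter):
toy_param = paddle.ones([4, 4], dtype='float16')
toy_master = paddle.to_tensor(toy_param.cast('float32').numpy(),
                              place=paddle.CPUPlace(),
                              stop_gradient=toy_param.stop_gradient)
# toy_master now lives on the CPU with dtype float32, like the entries of self._master_params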
def update_opt_status(self):
"""Update optimizer status and parameter storage information, and special functions to be developed.
......@@ -243,22 +270,43 @@ class ShardingOptimizerStage2(Optimizer):
A wrapper for Optimizer's step function to finish the update operation of the optimizer.
"""
# Synchronize optimizer parameters for the current rank
if len(self.dtype_rank_params.keys(
)) == 1 and Type.fp32.value in self.dtype_rank_params.keys():
self._optim._parameter_list = self.dtype_rank_params[
Type.fp32.value][self.rank]
elif len(self.dtype_rank_params.keys(
)) == 1 and Type.fp16.value in self.dtype_rank_params.keys():
self._optim._parameter_list = self.dtype_rank_params[
Type.fp16.value][self.rank]
if self.offload:
self._optim._parameter_list = [
param for name, param in self._master_params.items()
]
else:
self._optim._parameter_list = self.dtype_rank_params[
Type.fp16.value][self.rank] + self.dtype_rank_params[
# Synchronize optimizer parameters for the current rank
if len(self.dtype_rank_params.keys(
)) == 1 and Type.fp32.value in self.dtype_rank_params.keys():
self._optim._parameter_list = self.dtype_rank_params[
Type.fp32.value][self.rank]
elif len(self.dtype_rank_params.keys(
)) == 1 and Type.fp16.value in self.dtype_rank_params.keys():
self._optim._parameter_list = self.dtype_rank_params[
Type.fp16.value][self.rank]
else:
self._optim._parameter_list = self.dtype_rank_params[
Type.fp16.value][self.rank] + self.dtype_rank_params[
Type.fp32.value][self.rank]
# Run the optimizer of the current rank step
self._optim.step()
if self.offload:
with device_guard(self.rank, self.offload_device):
self._optim.step()
for param in self._optim._parameter_list:
self._master_params[param.name].set_value(param)
dev_id = 0 if paddle.get_device() == "cpu" else int(
paddle.get_device().split(":")[1])
for param in self._local_params:
if param.name in self._master_params.keys():
param.set_value(self._master_params[param.name].cuda(dev_id)
.cast(dtype=param.dtype))
self._master_params[param.name].clear_gradient(False)
else:
self._optim.step()
# Synchronize all the updated shards in between the ranks
self._broadcast_params()
......
......@@ -112,6 +112,18 @@ class ShardingStage2(nn.Layer):
self._has_grad_storage = []
self._grad_storage_list = []
# offload
# TODO(haohongxiang): the offload strategy does not support multiple optimizers yet
self._offload_optims = list(
filter(lambda optim: optim.offload, self._sharding_optimizers))
if len(self._offload_optims) > 0:
assert len(
self._sharding_optimizers
) == 1, "Only support offload strategy for single optimizer"
self._offload = self._sharding_optimizers[0].offload
self._offload_device = "cpu"
# Set backward pass hooks
self._bw_hooks = []
......@@ -156,7 +168,8 @@ class ShardingStage2(nn.Layer):
# Release grad storages
for dtype in self._grad_storages.keys():
if self._rank in self._grad_storages[dtype].keys():
self._grad_storages[dtype][self._rank].buffer.zero_()
if not self._offload:
self._grad_storages[dtype][self._rank].buffer.zero_()
# Release params
for param in self._trainable_params:
......@@ -167,17 +180,24 @@ class ShardingStage2(nn.Layer):
"""
Before the gradient accumulation, scale the gradient.
"""
# Scale grad storages
for dtype in self._grad_storages.keys():
if self._rank in self._grad_storages[dtype].keys():
self._grad_storages[dtype][self._rank].buffer.scale_(
scale=self._world_size_scaling)
# Scale params
for param in self._trainable_params:
if param.name in self._param_grads and param.grad is not None:
param.grad.scale_(scale=self._world_size_scaling)
param._reset_grad_inplace_version(True)
if self._offload:
for param in self._trainable_params:
if param.name in self._sharding_optimizers[
0]._master_params.keys():
self._sharding_optimizers[0]._master_params[
param.name].grad.scale_(scale=self._world_size_scaling)
else:
# Scale grad storages
for dtype in self._grad_storages.keys():
if self._rank in self._grad_storages[dtype].keys():
self._grad_storages[dtype][self._rank].buffer.scale_(
scale=self._world_size_scaling)
# Scale params
for param in self._trainable_params:
if param.name in self._param_grads and param.grad is not None:
param.grad.scale_(scale=self._world_size_scaling)
param._reset_grad_inplace_version(True)
def _init_internal_storage(self, needs_fresh):
"""
......@@ -195,8 +215,14 @@ class ShardingStage2(nn.Layer):
"""
Synchronously or asynchronously convert the data type of the layer, the device is not supported now.
"""
assert isinstance(device, str), "Device must be type str"
assert device == self._default_device, "New devices are not supported because the optimizer state is not synchronized"
self._layer.to(device=device, dtype=dtype, blocking=blocking)
# Re-build the buckets, hooks, etc..
self._fresh_trainable()
def _fresh_trainable(self):
""" Whether to update training parameters. """
......@@ -283,12 +309,17 @@ class ShardingStage2(nn.Layer):
self._grad_reduced[index] = False
if not self._accumulate_grads:
param.grad.scale_(scale=self._world_size_scaling)
param._reset_grad_inplace_version(True)
param._reset_grad_inplace_version(True)
# Clear the gradient that does not belong to the current rank through the callback function
def cleanup():
if dst_rank != self._rank:
param.clear_gradient(False)
elif self._offload:
self._sharding_optimizers[0]._master_params[
param.name]._copy_gradient_from(param.grad.cpu(
).cast(dtype=Type.fp32.value))
param.clear_gradient(False)
# Synchronize the reduce parameter gradient
self._tasks_flow.append(
......@@ -339,6 +370,15 @@ class ShardingStage2(nn.Layer):
grad_storage.buffer.value().get_tensor()._clear(
)
elif self._offload:
grad_storage.to(device=self._offload_device)
for param in grad_storage._params:
self._sharding_optimizers[0]._master_params[
param.name]._copy_gradient_from(
param.grad.cast(
dtype=Type.fp32.value))
grad_storage.buffer.value().get_tensor()._clear(
)
# Reduce the bucket
grad_storage.sent = True
......@@ -478,7 +518,7 @@ class ShardingStage2(nn.Layer):
# Rebuild fp16/fp32 grad storages
for dtype in self._grad_storages.keys():
for dst_rank, grad_storage in self._grad_storages[dtype].items():
if dst_rank != self._rank:
if self._offload or dst_rank != self._rank:
grad_storage.manumal_relase()
grad_storage.rebuild()
......
......@@ -17,10 +17,17 @@ import contextlib
from collections import abc
from enum import Enum
from math import inf
import numpy as np
from types import MethodType
import paddle
import paddle.distributed as dist
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid import layers
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import dygraph_only
from paddle.fluid.dygraph import base as imperative_base
class Taskflow:
......@@ -41,6 +48,88 @@ class Type(Enum):
fp32 = paddle.float32
class ShardingClipGrad:
def __init__(self, clip, group, device):
self._clip = clip
self._group = group
self._device = device
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
sum_square_fp16 = []
sum_square_fp32 = []
for p, g in params_grads:
if g is None or getattr(p, 'need_clip', True) is False:
continue
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.get_tensor_from_selected_rows(
layers.merge_selected_rows(g))
square = layers.square(merge_grad)
sum_square = layers.reduce_sum(square)
if p.dtype == paddle.float16:
sum_square_fp16.append(sum_square)
elif p.dtype == paddle.float32:
sum_square_fp32.append(sum_square)
# global norm of non-distributed FP16 params_and_grads
if len(sum_square_fp16) == 0:
global_norm_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
else:
global_norm_fp16 = layers.concat(sum_square_fp16)
global_norm_fp16 = layers.reduce_sum(global_norm_fp16)
global_norm_fp16 = paddle.cast(
global_norm_fp16, dtype=paddle.float32)
# global norm of non-distributed FP32 params_and_grads
global_norm_fp32 = layers.concat(sum_square_fp32) if len(
sum_square_fp32) != 0 else paddle.to_tensor(
[0.], dtype=paddle.float32)
global_norm_fp32 = layers.reduce_sum(global_norm_fp32)
global_norm_var = global_norm_fp16 + global_norm_fp32
# add all reduce to get global norm of distributed params_and_grads
dev_id = int(self._device.split(":")[1])
with device_guard(dev_id, "gpu"):
paddle.distributed.all_reduce(global_norm_var, group=self._group)
global_norm_var = layers.sqrt(global_norm_var)
max_global_norm = layers.fill_constant(
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
clip_var = layers.elementwise_div(
x=max_global_norm,
y=layers.elementwise_max(
x=global_norm_var, y=max_global_norm))
clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
if p.dtype == paddle.float16:
new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
else:
new_grad = layers.elementwise_mul(x=g, y=clip_var)
params_and_grads.append((p, new_grad))
return params_and_grads
def __getattr__(self, item):
return getattr(self._clip, item)
def __call__(self, params_grads):
return self._dygraph_clip(params_grads)
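# Illustrative sketch (not part of this commit): ShardingOptimizerStage2.__init__
# (see the optimizer diff above) performs this swap automatically whenever the
# wrapped optimizer was built with ClipGradByGlobalNorm, so the global norm is
# all-reduced across the sharding group before clipping. `model` and `group`
# are assumed to exist already.
inner_opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=model.parameters(),
    grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0))
inner_opt._grad_clip = ShardingClipGrad(inner_opt._grad_clip, group,
                                        paddle.get_device())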
@contextlib.contextmanager
def device_guard(dev_id, device="cpu"):
origin_device = paddle.device.get_device()
......@@ -52,3 +141,65 @@ def device_guard(dev_id, device="cpu"):
yield
finally:
paddle.set_device(origin_device)
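# Illustrative usage (not part of this commit): device_guard temporarily switches
# the dygraph device and always restores the previous one. This is how the offloaded
# optimizer step runs on the CPU, and how the all_reduce in ShardingClipGrad above
# runs on the GPU, without leaking the device switch to the rest of the program.
with device_guard(0, "cpu"):
    tmp = paddle.zeros([8], dtype='float32')  # created on the CPU
# execution continues on the original device here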
@dygraph_only
def ShardingScaler(scaler, sharding_group):
def unscale_method(self, optimizer):
if not self._enable:
return
param_grads = []
param_grads_fp16 = []
param_grads_fp32 = []
if getattr(optimizer, '_param_groups', None) and isinstance(
optimizer._param_groups[0], dict):
for group in optimizer._param_groups:
for param in group['params']:
if param._grad_ivar() is not None:
param_grads.append(param._grad_ivar())
if param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16:
param_grads_fp16.append(param._grad_ivar())
else:
param_grads_fp32.append(param._grad_ivar())
else:
param_grads = [
param._grad_ivar() for param in optimizer._parameter_list
if param._grad_ivar() is not None
]
param_grads_fp16 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None
) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP16
)
]
param_grads_fp32 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None
) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32
)
]
temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
if len(param_grads_fp16):
_C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
param_grads_fp16,
temp_found_inf_fp16)
if len(param_grads_fp32):
_C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
param_grads_fp32,
temp_found_inf_fp32)
self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
paddle.distributed.all_reduce(
is_found_inf,
op=paddle.distributed.ReduceOp.MAX,
group=sharding_group)
self._found_inf = is_found_inf.numpy()[0]
scaler._unscale = MethodType(unscale_method, scaler)
return scaler
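# Illustrative usage (not part of this commit), mirroring the new offload unit test
# further below: wrap the regular AMP GradScaler so that inf/nan flags are all-reduced
# across the sharding group before deciding whether to skip the optimizer step.
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
scaler = ShardingScaler(scaler, group)  # `group` from paddle.distributed.new_group([0, 1])

scaler.scale(loss).backward()  # `loss` computed under paddle.amp.auto_cast(level='O2')
scaler.step(optimizer)         # `optimizer` is the ShardingOptimizerStage2 wrapper
scaler.update()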
......@@ -50,6 +50,29 @@ class InternalStorage:
else:
self.buffer = paddle.zeros(size, dtype=dtype)
def to(self, device, dtype=None, keep_alignment=True):
"""
Move the underlying buffer
"""
assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
assert (dtype == Type.fp32.value or
Type.fp16.value), "Conversion type is not supported now"
dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device()
.split(":")[1])
if self._device != device:
tmp_buffer = self.buffer.cuda(
dev_id) if device == "gpu" else self.buffer.cpu()
for param in self._params:
param.clear_gradient(False)
param._gradient_set_empty(False)
self.buffer.value().get_tensor()._clear()
self.buffer = tmp_buffer
if dtype is not None:
self.buffer = self.buffer.cast(dtype=dtype)
class ParamStorage(InternalStorage):
"""
......@@ -60,6 +83,16 @@ class ParamStorage(InternalStorage):
super().__init__(size, dtype, device, convert_cpu=True)
self.param2align = None
def to(self, device, dtype=None, keep_alignment=True):
"""
Move the underlying buffer
"""
super().to(device, dtype)
if keep_alignment:
self._array_params()
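# Illustrative usage (not part of this commit), taken from the new offload unit test
# below: after training with offload, move each rank's flattened parameter buffer
# back onto the GPU in its original dtype so the parameter views are usable again.
for dtype in optimizer.param_storages:
    for dst_rank, param_storage in optimizer.param_storages[dtype].items():
        param_storage.to(device="gpu", dtype=dtype)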
@fluid.dygraph.no_grad
def add_rank_params(self, trainable_params, param2align):
"""
......@@ -78,7 +111,7 @@ class ParamStorage(InternalStorage):
p_shape = self._add_param_as_view(param, param2align[param.name])
cpu_param_shape.append(p_shape)
# buffer covert from cpu to cuda
# buffer convert from cpu to cuda
dev_id = int(paddle.get_device().split(":")[1])
self.buffer = self.buffer.cuda(dev_id)
self._fill = 0
......@@ -109,7 +142,8 @@ class ParamStorage(InternalStorage):
param.stop_gradient = origin_state
# Copy the current param value
dev_id = int(paddle.get_device().split(":")[1])
dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device()
.split(":")[1])
with device_guard(dev_id, "cpu"):
tmp_var = core.VarBase(tensor=self.buffer._slice(self._fill,
var_end))
......@@ -134,6 +168,18 @@ class ParamStorage(InternalStorage):
self._fill = offset
@fluid.dygraph.no_grad
def _array_params(self):
"""
Given the parameters which have been registered previously, rebuild the whole InternalStorage.
"""
assert len(self._params) > 0
assert self.param2align is not None
self._fill = 0
for p in self._params:
self._convert_buffer(p, p.shape, self.param2align[p.name]) # modify
class GradStorage(InternalStorage):
"""
......@@ -171,6 +217,18 @@ class GradStorage(InternalStorage):
param.shape) + align <= self._max_size and id(
param) not in self._param_ids
def to(self, device, dtype=None, keep_alignment=True):
"""
Move the underlying buffer
"""
if self._release:
self.rebuild()
super().to(device, dtype)
if keep_alignment:
self._array_grads()
@fluid.dygraph.no_grad
def add_grad(self, param, align):
"""
......@@ -206,17 +264,25 @@ class GradStorage(InternalStorage):
"""
Given the parameter gradients which have been registered previously, rebuild the whole InternalStorage.
"""
assert len(self._params) > 0
if self._release:
self.buffer = paddle.zeros(
[self._max_size], dtype=self._params[0].dtype)
self.buffer = paddle.zeros([self._max_size], dtype=self._dtype)
for p in self._params:
self._add_grad_as_view(p, self._parm2align[p.name])
self._release = False
@fluid.dygraph.no_grad
def _array_grads(self):
"""
Given the parameter gradients which have been registered previously, rebuild the whole InternalStorage.
"""
if len(self._params) > 0:
self._fill = 0
for p in self._params:
self._add_grad_as_view(p, self._parm2align[p.name])
@fluid.dygraph.no_grad
def _add_grad_as_view(self, param, align):
assert np.prod(
......@@ -229,8 +295,17 @@ class GradStorage(InternalStorage):
assert offset <= np.prod(self.buffer.shape)
# Copy the current grad value to InternalStorage
assert self._device == "gpu"
tmp_var = core.VarBase(self.buffer._slice(self._fill, grad_end))
param._copy_gradient_from(tmp_var)
tmp_var.value().get_tensor()._clear()
dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device()
.split(":")[1])
if self._device == "cpu":
with device_guard(dev_id, self._device):
tmp_var = core.VarBase(self.buffer._slice(self._fill, grad_end))
param._copy_gradient_from(tmp_var)
tmp_var.value().get_tensor()._clear()
elif self._device == "gpu":
tmp_var = core.VarBase(self.buffer._slice(self._fill, grad_end))
param._copy_gradient_from(tmp_var)
tmp_var.value().get_tensor()._clear()
self._fill = offset
......@@ -30,6 +30,7 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import Shar
seed = 2021
epoch = 2
batch_size = 32
linear_size = 10000
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
......@@ -45,12 +46,12 @@ paddle.seed(seed)
class MLP(fluid.Layer):
def __init__(self, param_attr=None, bias_attr=None):
def __init__(self, linear_size=10000, param_attr=None, bias_attr=None):
super(MLP, self).__init__()
self._linear1 = Linear(10000, 10000)
self._linear2 = Linear(10000, 10000)
self._linear3 = Linear(10000, 10)
self._linear1 = Linear(linear_size, linear_size)
self._linear2 = Linear(linear_size, linear_size)
self._linear3 = Linear(linear_size, 10)
def forward(self, inputs):
y = self._linear1(inputs)
......@@ -59,10 +60,10 @@ class MLP(fluid.Layer):
return y
def reader_decorator():
def reader_decorator(linear_size=10000):
def __reader__():
for _ in range(100):
img = np.random.rand(10000).astype('float32')
img = np.random.rand(linear_size).astype('float32')
label = np.ones(1).astype('int64')
yield img, label
......@@ -120,6 +121,9 @@ def train_mlp(model,
use_multiprocess=True)
train_loader.set_sample_list_generator(train_reader)
if sharding_stage == 2:
model.to(device="gpu")
for eop in range(epoch):
model.train()
......@@ -153,9 +157,6 @@ def train_mlp(model,
if all_test and batch_id == 2:
return model.parameters()
if sharding_stage == 2:
model.to(device="gpu")
return model.parameters()
......
# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import ast
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
from dygraph_sharding_stage2 import MLP, reader_decorator, optimizer_setting
seed = 2021
epoch = 2
batch_size = 32
linear_size = 8000
np.random.seed(seed)
paddle.seed(seed)
def train_mlp(model, offload=False):
group = paddle.distributed.new_group([0, 1])
optimizer = optimizer_setting(model=model, use_pure_fp16=True)
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
scaler = ShardingScaler(scaler, group)
optimizer = ShardingOptimizerStage2(
params=model.parameters(),
optim=optimizer,
group=group,
offload=offload)
model = ShardingStage2(model, optimizer, group=group, accumulate_grads=True)
train_reader = paddle.batch(
reader_decorator(linear_size), batch_size=batch_size, drop_last=True)
train_loader = paddle.io.DataLoader.from_generator(
capacity=32,
use_double_buffer=True,
iterable=True,
return_list=True,
use_multiprocess=True)
train_loader.set_sample_list_generator(train_reader)
for eop in range(epoch):
model.train()
for batch_id, data in enumerate(train_loader()):
img, label = data
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(True, level='O2'):
out = model(img)
loss = paddle.nn.functional.cross_entropy(
input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
scaler.scale(avg_loss).backward()
model.grad_scale()
scaler.step(optimizer)
scaler.update()
model.clear_gradients()
for dtype in optimizer.param_storages:
for dst_rank, param_storage in optimizer.param_storages[dtype].items():
param_storage.to(device="gpu", dtype=dtype)
return model.parameters()
def test_sharding_stage2_offload():
mlp = MLP(linear_size)
mlp_offload = MLP(linear_size)
mlp_offload.set_state_dict(mlp.state_dict())
mlp_params = train_mlp(mlp, offload=False)
mlp_offload_params = train_mlp(mlp_offload, offload=True)
for i in range(len(mlp_params)):
for j in range(len(mlp_offload_params)):
if mlp_params[i].name == mlp_offload_params[j].name:
np.testing.assert_allclose(
mlp_params[i].numpy(),
mlp_offload_params[j].numpy(),
rtol=1e-6)
return
if __name__ == '__main__':
test_sharding_stage2_offload()
......@@ -26,6 +26,9 @@ class TestDygraphShardingStage2(TestMultipleGpus):
def test_dygraph_sharding_optimizer_stage2(self):
self.run_mnist_2gpu('dygraph_sharding_stage2.py')
def test_dygraph_sharding_optimizer_stage2_offload(self):
self.run_mnist_2gpu('dygraph_sharding_stage2_offload.py')
if __name__ == "__main__":
unittest.main()