Unverified · Commit a96c6dc7 · authored by sneaxiy · committed by GitHub

Fix A100 CUDA12 ut (#54487)

* fix A100 CUDA12 ut

* fix ci uts

* fix test_sync_batch_norm_op

* fix sync bn op ut again by separating 2 files

* fix codestyle ci

* combine other PRs

* fix codestyle

* fix codestyle ci
Parent 45ba9cf0
......@@ -277,7 +277,22 @@ PD_REGISTER_KERNEL(coalesce_tensor,
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_CUDA
PD_REGISTER_KERNEL(coalesce_tensor,
GPU,
ALL_LAYOUT,
phi::CoalesceTensorKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
float,
double) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}
#endif
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(coalesce_tensor,
GPU,
ALL_LAYOUT,
......
......@@ -181,6 +181,23 @@ void SGDSparseParamSparseGradKernel(
} // namespace phi
#ifdef PADDLE_WITH_CUDA
PD_REGISTER_KERNEL(sgd,
GPU,
ALL_LAYOUT,
phi::SGDDenseKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double) {
if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
kernel_key.dtype() == phi::DataType::BFLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
}
}
#endif
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(sgd,
GPU,
ALL_LAYOUT,
......@@ -192,6 +209,7 @@ PD_REGISTER_KERNEL(sgd,
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
}
}
#endif
PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
GPU,
......
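The FLOAT32 override on `kernel->OutputAt(1)` above follows the usual master-weight convention for low-precision SGD: when the parameter dtype is FP16 or BF16, the second output (the master copy) is kept in FP32. A minimal NumPy-only sketch of why that matters (illustrative, not the Paddle kernel itself):

```python
import numpy as np

# Parameters are stored in fp16, but the update accumulates into an fp32
# master copy; otherwise lr * grad can round away in half precision.
param_fp16 = np.float16(1.0)
master_fp32 = np.float32(param_fp16)

lr = np.float32(1e-4)
grad = np.float32(0.05)

master_fp32 -= lr * grad                # the tiny update survives in fp32
param_fp16 = np.float16(master_fp32)    # cast back for the next forward pass

print(master_fp32, param_fp16)          # 0.999995 vs. 1.0 after fp16 rounding
```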
......@@ -888,7 +888,7 @@ def conv2d(
"""
check_variable_and_dtype(
input, 'input', ['float16', 'float32', 'float64'], 'conv2d'
input, 'input', ['uint16', 'float16', 'float32', 'float64'], 'conv2d'
)
if len(input.shape) != 4:
raise ValueError(
......@@ -2739,12 +2739,15 @@ def batch_norm(
helper = LayerHelper('batch_norm', **locals())
check_variable_and_dtype(
input, 'input', ['float16', 'float32', 'float64'], 'batch_norm'
input,
'input',
['uint16', 'float16', 'float32', 'float64'],
'batch_norm',
)
dtype = helper.input_dtype()
# use fp32 for bn parameter
if dtype == core.VarDesc.VarType.FP16:
if dtype == core.VarDesc.VarType.FP16 or dtype == core.VarDesc.VarType.BF16:
dtype = core.VarDesc.VarType.FP32
input_shape = input.shape
......
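The new 'uint16' entries in the dtype checks above reflect how Paddle's static graph stores bfloat16 tensors: as the raw 16-bit pattern with a uint16 dtype, which is also why the batch-norm parameters fall back to FP32 for BF16 inputs. A hedged NumPy-only sketch of that representation (`bf16_to_fp32`/`fp32_to_bf16` are illustrative helpers, not Paddle APIs):

```python
import numpy as np

def bf16_to_fp32(bits_u16: np.ndarray) -> np.ndarray:
    # bfloat16 keeps the top 16 bits of a float32, so shifting the stored
    # uint16 pattern into the high half of a uint32 recovers the float32 value.
    return (bits_u16.astype(np.uint32) << 16).view(np.float32)

def fp32_to_bf16(x: np.ndarray) -> np.ndarray:
    # Truncate a float32 array to its bfloat16 bit pattern (round toward zero).
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

x = np.array([0.5, -1.25, 3.0], dtype=np.float32)
assert np.allclose(bf16_to_fp32(fp32_to_bf16(x)), x)
```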
......@@ -36,8 +36,6 @@ class TestShardingPass(unittest.TestCase):
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
......
......@@ -13,8 +13,6 @@
# limitations under the License.
import os
import pickle
import sys
import test_collective_api_base as test_base
......@@ -148,7 +146,7 @@ class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
out = exe.run(
train_prog, feed={'tindata': indata}, fetch_list=fetch_list
)
sys.stdout.buffer.write(pickle.dumps(out))
test_base.dump_output(out)
if __name__ == "__main__":
......
......@@ -13,12 +13,11 @@
# limitations under the License.
import os
import pickle
import sys
import numpy as np
from legacy_test.test_collective_api_base import (
TestCollectiveAPIRunnerBase,
dump_output,
runtime_main,
)
......@@ -124,7 +123,7 @@ class TestCollectiveGlobalGatherAPI(TestCollectiveAPIRunnerBase):
fetch_list=fetch_list,
)
sys.stdout.buffer.write(pickle.dumps(out))
dump_output(out)
if __name__ == "__main__":
......
......@@ -13,12 +13,11 @@
# limitations under the License.
import os
import pickle
import sys
import numpy as np
from legacy_test.test_collective_api_base import (
TestCollectiveAPIRunnerBase,
dump_output,
runtime_main,
)
......@@ -103,7 +102,7 @@ class TestCollectiveGlobalScatterAPI(TestCollectiveAPIRunnerBase):
fetch_list=fetch_list,
)
sys.stdout.buffer.write(pickle.dumps(out))
dump_output(out)
if __name__ == "__main__":
......
......@@ -18,8 +18,8 @@ import random
import numpy as np
from legacy_test.test_dist_base import (
TestParallelDyGraphRunnerBase,
dump_output,
print_to_err,
print_to_out,
runtime_main,
)
......@@ -92,7 +92,7 @@ class TestNoSync(TestParallelDyGraphRunnerBase):
)
print_to_err(type(self).__name__, "model built in dygraph")
out_losses = self.model_train(args, model, opt, train_reader)
print_to_out(out_losses)
dump_output(out_losses)
return out_losses
def run_trainer_with_spawn_func(self, args):
......@@ -120,7 +120,7 @@ class TestNoSync(TestParallelDyGraphRunnerBase):
)
out_losses = self.model_train(args, model, opt, train_reader)
print_to_out(out_losses)
dump_output(out_losses)
return out_losses
def model_train(self, args, model, opt, train_reader):
......
......@@ -67,7 +67,6 @@ class TestCollectiveReduceAPI(TestDistBase):
def test_reduce_gloo_with_comm_context(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
......@@ -115,7 +114,6 @@ class TestCollectiveReduceAPI(TestDistBase):
def test_reduce_gloo_dygraph(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
......
......@@ -15,7 +15,6 @@
import os
from dist_mnist import cnn_model # noqa: F401
from test_dist_base import dump_output
import paddle
from paddle import fluid
......@@ -28,8 +27,12 @@ fluid.default_main_program().random_seed = 1
def runtime_main():
from test_dist_base import dump_output
from paddle.distributed import fleet
paddle.enable_static()
# model definition
train_prog = paddle.fluid.Program()
startup_prog = paddle.fluid.Program()
......
......@@ -80,6 +80,12 @@ def create_pyobject_test_data(shape=None, seed=None):
return [list_data, dict_data]
def dump_output(x):
dump_file = os.environ['DUMP_FILE']
with open(dump_file, 'wb') as f:
pickle.dump(x, f)
def create_test_data(shape=None, dtype=None, seed=None):
assert shape, "Shape should be specified"
if dtype == "float32" or dtype == "float16" or dtype == "float64":
......@@ -160,7 +166,7 @@ class TestCollectiveAPIRunnerBase:
else:
out = self.get_model(train_prog, startup_prog, rank, indata)
# print(out, sys.stderr)
sys.stdout.buffer.write(pickle.dumps(out))
dump_output(out)
def runtime_main(test_class, col_type):
......@@ -255,6 +261,13 @@ class TestDistBase(unittest.TestCase):
# update environment
env0.update(envs)
env1.update(envs)
cur_pid = os.getpid()
dump_file_0 = f'./out_data_0_{cur_pid}.pickled'
dump_file_1 = f'./out_data_1_{cur_pid}.pickled'
env0['DUMP_FILE'] = dump_file_0
env1['DUMP_FILE'] = dump_file_1
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
tr_cmd = "%s -m coverage run --branch -p %s"
else:
......@@ -295,9 +308,16 @@ class TestDistBase(unittest.TestCase):
sys.stderr.write('trainer 0 stderr file: %s\n' % f.read())
with open(path1, "r") as f:
sys.stderr.write('trainer 1 stderr file: %s\n' % f.read())
def load_and_remove(path):
with open(path, 'rb') as f:
out = pickle.load(f)
os.remove(path)
return out
return (
pickle.loads(tr0_out),
pickle.loads(tr1_out),
load_and_remove(dump_file_0),
load_and_remove(dump_file_1),
tr0_proc.pid,
tr1_proc.pid,
)
......@@ -469,7 +489,7 @@ class TestDistBase(unittest.TestCase):
elif col_type == "column_parallel_linear":
result_data = tr0_out[0]
np.random.seed(2020)
weight = np.random.rand(1000, 16)
weight = np.random.rand(1000, 16).astype(np.float32)
need_result = np.matmul(input1, weight)
np.testing.assert_allclose(
result_data, need_result, rtol=1e-05, atol=1e-05
......
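The recurring change above replaces `sys.stdout.buffer.write(pickle.dumps(out))` with a file-based handoff: each trainer writes its pickled result to the path given by the DUMP_FILE environment variable, and the parent test reads it back with `load_and_remove`. A self-contained sketch of that round trip (the file name and payload are illustrative):

```python
import os
import pickle

def dump_output(x):
    # Child side: write the result to DUMP_FILE instead of stdout, so stray
    # log lines on stdout cannot corrupt the pickled payload.
    with open(os.environ['DUMP_FILE'], 'wb') as f:
        pickle.dump(x, f)

def load_and_remove(path):
    # Parent side: read the result back and delete the temporary file.
    with open(path, 'rb') as f:
        out = pickle.load(f)
    os.remove(path)
    return out

if __name__ == '__main__':
    path = f'./out_data_0_{os.getpid()}.pickled'
    os.environ['DUMP_FILE'] = path
    dump_output({'loss': [0.5, 0.25]})   # what a trainer process would do
    print(load_and_remove(path))         # what the test process would do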
......@@ -126,7 +126,9 @@ class TestCollectiveRunnerBase:
out = exe.run(
train_prog, feed={'tindata': indata}, fetch_list=[result.name]
)
sys.stdout.buffer.write(pickle.dumps(out))
dump_file = os.environ['DUMP_FILE']
with open(dump_file, 'wb') as f:
pickle.dump(out, f)
def runtime_main(test_class, col_type, sub_type):
......@@ -189,9 +191,17 @@ class TestDistBase(unittest.TestCase):
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w1_ep,
}
cur_pid = os.getpid()
dump_file_0 = f'./out_data_0_{cur_pid}.pickled'
dump_file_1 = f'./out_data_1_{cur_pid}.pickled'
# update environment
env0.update(envs)
env1.update(envs)
env0['DUMP_FILE'] = dump_file_0
env1['DUMP_FILE'] = dump_file_1
tr_cmd = "%s %s"
tr0_cmd = tr_cmd % (self._python_interp, model_file)
tr1_cmd = tr_cmd % (self._python_interp, model_file)
......@@ -221,9 +231,16 @@ class TestDistBase(unittest.TestCase):
# close trainer file
tr0_pipe.close()
tr1_pipe.close()
def load_and_remove(path):
with open(path, 'rb') as f:
out = pickle.load(f)
os.remove(path)
return out
return (
pickle.loads(tr0_out),
pickle.loads(tr1_out),
load_and_remove(dump_file_0),
load_and_remove(dump_file_1),
tr0_proc.pid,
tr1_proc.pid,
)
......
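In test_collective_base.py the parent now allocates one dump file per trainer, injects it through the child environment, and collects both results after the processes exit. A hedged, standalone sketch of that flow (the model script name is a placeholder):

```python
import os
import pickle
import subprocess
import sys

cur_pid = os.getpid()
dump_files = [f'./out_data_{rank}_{cur_pid}.pickled' for rank in (0, 1)]

procs = []
for rank, dump_file in enumerate(dump_files):
    env = dict(os.environ, DUMP_FILE=dump_file)   # per-rank output path
    procs.append(
        subprocess.Popen(
            [sys.executable, 'collective_allreduce_op.py'],  # placeholder script
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )
    )

for p in procs:
    p.communicate()

outs = []
for dump_file in dump_files:
    with open(dump_file, 'rb') as f:
        outs.append(pickle.load(f))   # pickled trainer result
    os.remove(dump_file)
```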
......@@ -44,8 +44,31 @@ DEFAULT_BATCH_SIZE = 2
DIST_UT_PORT = 0
def print_to_out(out_losses):
sys.stdout.buffer.write(pickle.dumps(out_losses))
def remove_glog_envs(envs):
if not envs:
return envs
glog_envs = ['GLOG_v', 'GLOG_logtostderr', 'GLOG_vmodule']
envs = dict(envs)
for env in glog_envs:
if env in envs:
del envs[env]
return envs
def get_dump_file(rank):
return f"./out_dump_{os.getpid()}_{rank}.pickled"
def modify_envs(envs, rank=0):
if not envs:
envs = {}
envs = remove_glog_envs(envs)
dump_file = get_dump_file(rank)
envs['DUMP_FILE'] = dump_file
if os.path.exists(dump_file):
os.remove(dump_file)
return envs
def dump_output(x):
......@@ -54,7 +77,8 @@ def dump_output(x):
pickle.dump(x, f)
def load_and_remove(path):
def load_and_remove_dump_file(rank=0):
path = get_dump_file(rank)
with open(path, 'rb') as f:
out = pickle.load(f)
os.remove(path)
......@@ -1084,14 +1108,14 @@ class TestDistBase(unittest.TestCase):
ps0_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=ps0_pipe,
env=required_envs,
env=modify_envs(required_envs),
)
print_to_err(type(self).__name__, "going to start pserver process 1")
ps1_proc = subprocess.Popen(
ps1_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=ps1_pipe,
env=required_envs,
env=modify_envs(required_envs),
)
return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe
......@@ -1152,9 +1176,6 @@ class TestDistBase(unittest.TestCase):
cmd += " --find_unused_parameters"
env_local.update(envs)
cur_pid = os.getpid()
dump_file = f"out_data_local_{cur_pid}.pickled"
env_local["DUMP_FILE"] = dump_file
print(f"local_cmd: {cmd}, env: {env_local}")
if check_error_log:
......@@ -1164,14 +1185,14 @@ class TestDistBase(unittest.TestCase):
cmd.split(" "),
stdout=subprocess.PIPE,
stderr=err_log,
env=env_local,
env=modify_envs(env_local),
)
else:
local_proc = subprocess.Popen(
cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env_local,
env=modify_envs(env_local),
)
local_out, local_err = local_proc.communicate()
......@@ -1181,7 +1202,7 @@ class TestDistBase(unittest.TestCase):
sys.stderr.write('local_stderr: %s\n' % local_err)
return load_and_remove(dump_file)
return load_and_remove_dump_file()
def _run_local_gloo(
self,
......@@ -1260,14 +1281,6 @@ class TestDistBase(unittest.TestCase):
env0.update(envs)
env1.update(envs)
cur_pid = os.getpid()
dump_files = [
f'./out_data_0_{cur_pid}.pickled',
f'./out_data_1_{cur_pid}.pickled',
]
env0["DUMP_FILE"] = dump_files[0]
env1["DUMP_FILE"] = dump_files[1]
print(f"tr0_cmd: {tr0_cmd}, env: {env0}")
print(f"tr1_cmd: {tr1_cmd}, env: {env1}")
......@@ -1281,14 +1294,14 @@ class TestDistBase(unittest.TestCase):
tr0_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=tr0_pipe,
env=env0,
env=modify_envs(env0, 0),
)
print_to_err(type(self).__name__, "going to start trainer process 1")
tr1_proc = subprocess.Popen(
tr1_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=tr1_pipe,
env=env1,
env=modify_envs(env1, 1),
)
# Wait until trainer process terminate
......@@ -1315,7 +1328,7 @@ class TestDistBase(unittest.TestCase):
ps0.terminate()
ps1.terminate()
return load_and_remove(dump_files[0]), load_and_remove(dump_files[1])
return load_and_remove_dump_file(0), load_and_remove_dump_file(1)
def _get_gloo_trainer_cmd(
self, model, ep, update_method, trainer_id, trainer_num
......@@ -1502,8 +1515,6 @@ class TestDistBase(unittest.TestCase):
procs = []
pipes = []
dump_files = []
cur_pid = os.getpid()
for i in range(0, trainer_num):
tr_cmd, tr_env = self._get_gloo_trainer_cmd(
model, worker_endpoints[i], update_method, i, trainer_num
......@@ -1511,10 +1522,6 @@ class TestDistBase(unittest.TestCase):
tr_env.update(envs)
tr_env["GLOG_vmodule"] = 'gloo_context=4'
tr_env["GLOG_v"] = '3'
dump_file = f'./out_data_{i}_{cur_pid}.pickled'
dump_files.append(dump_file)
tr_env["DUMP_FILE"] = dump_file
print(
"use_hallreduce:{} tr_cmd:{}, env: {}".format(
self._use_hallreduce, tr_cmd, tr_env
......@@ -1534,7 +1541,7 @@ class TestDistBase(unittest.TestCase):
tr_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=tr_pipe,
env=tr_env,
env=modify_envs(tr_env, i),
)
procs.append(tr_proc)
......@@ -1550,15 +1557,13 @@ class TestDistBase(unittest.TestCase):
if trainer_num == 1:
if check_error_log:
print("outs[0]:", outs[0])
return load_and_remove(dump_files[0])
return load_and_remove_dump_file(0)
else:
if check_error_log:
print("outs[0]:", outs[0])
print("outs[1]:", outs[1])
return load_and_remove(dump_files[0]), load_and_remove(
dump_files[1]
)
return load_and_remove_dump_file(0), load_and_remove_dump_file(1)
def _run_cluster_nccl2(
self, model, envs, update_method, check_error_log, log_name
......@@ -1586,16 +1591,11 @@ class TestDistBase(unittest.TestCase):
procs = []
pipes = []
cur_pid = os.getpid()
dump_files = []
for i in range(0, trainer_num):
tr_cmd, tr_env = self._get_nccl2_trainer_cmd(
model, worker_endpoints[i], update_method, i, trainer_num
)
tr_env.update(envs)
dump_file = f'./out_data_{i}_{cur_pid}.pickled'
dump_files.append(dump_file)
tr_env["DUMP_FILE"] = dump_file
print(
"use_hallreduce:{} tr_cmd:{}, env: {}".format(
self._use_hallreduce, tr_cmd, tr_env
......@@ -1615,7 +1615,7 @@ class TestDistBase(unittest.TestCase):
tr_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=tr_pipe,
env=tr_env,
env=modify_envs(tr_env, i),
)
procs.append(tr_proc)
......@@ -1632,7 +1632,7 @@ class TestDistBase(unittest.TestCase):
print("outs[0]:", outs[0])
print("outs[1]:", outs[1])
return load_and_remove(dump_files[0]), load_and_remove(dump_files[1])
return load_and_remove_dump_file(0), load_and_remove_dump_file(1)
def _run_pipeline(self, model, envs, check_error_log, log_name):
# NOTE: we reuse ps_endpoints as nccl2 worker endpoints
......@@ -1643,8 +1643,6 @@ class TestDistBase(unittest.TestCase):
procs = []
pipes = []
cur_pid = os.getpid()
dump_files = []
for i in range(0, trainer_num):
tr_cmd, tr_env = self._get_nccl2_trainer_cmd(
model, worker_endpoints[i], update_method, i, trainer_num
......@@ -1654,10 +1652,6 @@ class TestDistBase(unittest.TestCase):
tr_env['NCCL_SHM_DISABLE'] = '1'
tr_env['FLAGS_selected_gpus'] = str(i)
tr_env['FLAGS_cudnn_deterministic'] = '0'
dump_file = f'./out_data_{i}_{cur_pid}.pickled'
dump_files.append(dump_file)
tr_env["DUMP_FILE"] = dump_file
print(f"tr_cmd:{tr_cmd}, env: {tr_env}")
path = os.path.join(self.temp_dir.name + f"tr{i}_err.log")
......@@ -1671,7 +1665,7 @@ class TestDistBase(unittest.TestCase):
tr_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=tr_pipe,
env=tr_env,
env=modify_envs(tr_env, i),
)
procs.append(tr_proc)
......@@ -1687,7 +1681,7 @@ class TestDistBase(unittest.TestCase):
if check_error_log:
print("outs[0]:", outs[0])
print("outs[1]:", outs[1])
return load_and_remove(dump_files[0]), load_and_remove(dump_files[1])
return load_and_remove_dump_file(0), load_and_remove_dump_file(1)
def _get_required_envs(self, check_error_log=False, need_envs={}):
# TODO(typhoonzero): should auto adapt GPU count on the machine.
......
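test_dist_base.py now funnels every trainer environment through `modify_envs`, which strips GLOG_* variables and assigns the per-rank dump file, so the per-call DUMP_FILE bookkeeping removed above is no longer needed. A short usage sketch (helper bodies mirror the diff; the sample environment is illustrative):

```python
import os

def remove_glog_envs(envs):
    # Drop glog verbosity settings so trainer stderr stays readable.
    glog_envs = ('GLOG_v', 'GLOG_logtostderr', 'GLOG_vmodule')
    return {k: v for k, v in (envs or {}).items() if k not in glog_envs}

def get_dump_file(rank):
    return f"./out_dump_{os.getpid()}_{rank}.pickled"

def modify_envs(envs, rank=0):
    envs = remove_glog_envs(envs)
    dump_file = get_dump_file(rank)
    envs['DUMP_FILE'] = dump_file
    if os.path.exists(dump_file):
        os.remove(dump_file)          # stale file from an earlier failed run
    return envs

env0 = modify_envs({'GLOG_v': '3', 'FLAGS_selected_gpus': '0'}, rank=0)
assert 'GLOG_v' not in env0 and env0['DUMP_FILE'].endswith('_0.pickled')
```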
......@@ -18,9 +18,11 @@ for both FP64 and FP16 input.
import os
import random
import subprocess
import shutil
import sys
import tempfile
import unittest
from shlex import quote
import numpy as np
from decorator_helper import prog_scope
......@@ -33,10 +35,41 @@ from eager_op_test import (
import paddle
from paddle import fluid, nn
from paddle.fluid import Program, core, program_guard
from paddle.fluid.framework import in_dygraph_mode
_set_use_system_allocator(True)
def enable_static():
if in_dygraph_mode():
paddle.enable_static()
def cleanup():
paddle.disable_static()
else:
def cleanup():
pass
return cleanup
def convert_numpy_array(array):
if array.dtype != np.uint16:
return array
cleanup = None
if not in_dygraph_mode():
paddle.disable_static()
cleanup = lambda: paddle.enable_static()
out = paddle.to_tensor(array).astype(paddle.float32).numpy()
if cleanup is not None:
cleanup()
return out
def create_or_get_tensor(scope, var_name, var, place):
"""Get tensor, if not found, create a new one."""
tensor = scope.var(var_name).get_tensor()
......@@ -47,6 +80,24 @@ def create_or_get_tensor(scope, var_name, var, place):
return tensor
def clean_dir(path):
if isinstance(path, tempfile.TemporaryDirectory):
path = path.name
for f in os.listdir(path):
f = os.path.join(path, f)
if os.path.isdir(f):
shutil.rmtree(f)
else:
os.remove(f)
def concat_cmd(cmd):
if isinstance(cmd, str):
return cmd
return ' '.join([quote(c) for c in cmd])
class TestSyncBatchNormOpTraining(unittest.TestCase):
"""sync_batch_norm op test."""
......@@ -69,7 +120,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
def multi_device_run(self, layout, fetch_list, only_forward=False):
cmds = [
"python",
sys.executable,
"-m",
"paddle.distributed.launch",
]
......@@ -91,8 +142,8 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
cmds += ["--only_forward"]
if self.dtype == np.float16 or self.dtype == np.uint16:
cmds += ["--use_cudnn"]
p = subprocess.run(cmds)
assert p.returncode == 0, f"Fleet train: Failed: {p}"
cmd = concat_cmd(cmds)
assert os.system(cmd) == 0, cmd
def _build_program(
self, place, layout, seed, sync_bn=False, only_forward=False
......@@ -143,9 +194,18 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
@prog_scope()
def _compare(self, place, layout, only_forward):
try:
with paddle.utils.unique_name.guard():
self._compare_impl(place, layout, only_forward)
finally:
clean_dir(self.data_dir)
clean_dir(self.fleet_log_dir)
def _compare_impl(self, place, layout, only_forward):
"""Compare results."""
seed = 10
os.environ['FLAGS_cudnn_deterministic'] = "1"
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
paddle.enable_static()
scope = core.Scope()
if self.dtype == np.uint16:
......@@ -234,8 +294,8 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
if sync_bn_val.shape != bn_val.shape:
bn_val = bn_val[:stride]
np.testing.assert_allclose(
bn_val,
sync_bn_val,
convert_numpy_array(bn_val),
convert_numpy_array(sync_bn_val),
rtol=1e-05,
atol=self.atol,
err_msg='Output ('
......@@ -311,6 +371,7 @@ class TestDygraphSyncBatchNormAPIError(unittest.TestCase):
if not core.is_compiled_with_cuda():
return
cleanup = enable_static()
with program_guard(Program(), Program()):
my_sync_batch_norm = paddle.nn.SyncBatchNorm(10)
x1 = fluid.create_lod_tensor(
......@@ -325,6 +386,7 @@ class TestDygraphSyncBatchNormAPIError(unittest.TestCase):
)
x2.desc.set_need_check_feed(False)
self.assertRaises(TypeError, my_sync_batch_norm, x2)
cleanup()
class TestConvertSyncBatchNorm(unittest.TestCase):
......@@ -384,71 +446,6 @@ class TestConvertSyncBatchNormCast1(unittest.TestCase):
self.assertEqual(len(compare_model.sublayers()), len(model.sublayers()))
class TestConvertSyncBatchNormCase2(unittest.TestCase):
def test_convert(self):
if not core.is_compiled_with_cuda():
return
with fluid.dygraph.guard(fluid.CUDAPlace(0)):
class SyBNNet(paddle.nn.Layer):
def __init__(self, in_ch=3, out_ch=3, dirate=1):
super().__init__()
self.bn_s1 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
paddle.nn.BatchNorm3D(
out_ch,
weight_attr=paddle.ParamAttr(
regularizer=paddle.regularizer.L2Decay(0.0)
),
)
)
self.bn_s2 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
paddle.nn.BatchNorm3D(out_ch, data_format='NDHWC')
)
def forward(self, x):
x = self.bn_s1(x)
out = paddle.sum(paddle.abs(self.bn_s2(x)))
return out
class BNNet(paddle.nn.Layer):
def __init__(self, in_ch=3, out_ch=3, dirate=1):
super().__init__()
self.bn_s1 = paddle.nn.BatchNorm3D(
out_ch,
weight_attr=paddle.ParamAttr(
regularizer=paddle.regularizer.L2Decay(0.0)
),
)
self.bn_s2 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
paddle.nn.BatchNorm3D(out_ch, data_format='NDHWC')
)
def forward(self, x):
x = self.bn_s1(x)
out = paddle.sum(paddle.abs(self.bn_s2(x)))
return out
bn_model = BNNet()
sybn_model = SyBNNet()
np.random.seed(10)
data = np.random.random([3, 3, 3, 3, 3]).astype('float32')
x = paddle.to_tensor(data)
bn_out = bn_model(x)
sybn_out = sybn_model(x)
np.testing.assert_allclose(
bn_out.numpy(),
sybn_out.numpy(),
rtol=1e-05,
err_msg='Output has diff. \n'
+ '\nBN '
+ str(bn_out.numpy())
+ '\n'
+ 'Sync BN '
+ str(sybn_out.numpy()),
)
class TestDygraphSyncBatchNormDataFormatError(unittest.TestCase):
def test_errors(self):
if not core.is_compiled_with_cuda():
......
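`multi_device_run` above now builds the launch command with `sys.executable` and runs it through `os.system` after shell-quoting each argument; `concat_cmd` does the joining. A standalone sketch of that pattern (the log directory and training-script name are placeholders):

```python
import sys
from shlex import quote

def concat_cmd(cmd):
    # Accept either a prebuilt command string or a list of arguments.
    if isinstance(cmd, str):
        return cmd
    return ' '.join(quote(c) for c in cmd)

cmds = [
    sys.executable,
    "-m",
    "paddle.distributed.launch",
    "--log_dir",
    "/tmp/sync_bn_logs",               # placeholder log directory
    "dist_sync_batch_norm_script.py",  # placeholder training script
    "--layout",
    "NCHW",
]
cmd = concat_cmd(cmds)
print(cmd)  # the test then asserts os.system(cmd) == 0
```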
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
class SyBNNet(paddle.nn.Layer):
def __init__(self, in_ch=3, out_ch=3, dirate=1):
super().__init__()
self.bn_s1 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
paddle.nn.BatchNorm3D(
out_ch,
weight_attr=paddle.ParamAttr(
regularizer=paddle.regularizer.L2Decay(0.0)
),
)
)
self.bn_s2 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
paddle.nn.BatchNorm3D(out_ch, data_format='NDHWC')
)
def forward(self, x):
x = self.bn_s1(x)
out = paddle.sum(paddle.abs(self.bn_s2(x)))
return out
class BNNet(paddle.nn.Layer):
def __init__(self, in_ch=3, out_ch=3, dirate=1):
super().__init__()
self.bn_s1 = paddle.nn.BatchNorm3D(
out_ch,
weight_attr=paddle.ParamAttr(
regularizer=paddle.regularizer.L2Decay(0.0)
),
)
self.bn_s2 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
paddle.nn.BatchNorm3D(out_ch, data_format='NDHWC')
)
def forward(self, x):
x = self.bn_s1(x)
out = paddle.sum(paddle.abs(self.bn_s2(x)))
return out
class TestConvertSyncBatchNormCase(unittest.TestCase):
def test_convert(self):
if not paddle.is_compiled_with_cuda():
return
bn_model = BNNet()
sybn_model = SyBNNet()
np.random.seed(10)
data = np.random.random([3, 3, 3, 3, 3]).astype('float32')
x = paddle.to_tensor(data)
bn_out = bn_model(x)
sybn_out = sybn_model(x)
np.testing.assert_allclose(
bn_out.numpy(),
sybn_out.numpy(),
rtol=1e-05,
err_msg='Output has diff. \n'
+ '\nBN '
+ str(bn_out.numpy())
+ '\n'
+ 'Sync BN '
+ str(sybn_out.numpy()),
)
if __name__ == "__main__":
unittest.main()