Unverified · Commit 38d664b7 authored by W wz1qqx, committed by GitHub

[XPU]Conv transpose fp16 && fix unittest (#53626)

* fix per review comments, add fp16 conv2d_transpose

* fix unittest of bn and reduce_mean

* fix bn unittest

* fix ci

* fix ci
Parent 4f33f44b
......@@ -34,25 +34,6 @@ class Scope;
} // namespace framework
} // namespace paddle
namespace {
template <typename T1, typename T2>
void ConvertTensorType(phi::DenseTensor* tensor) {
phi::DenseTensor tmp_tensor;
tmp_tensor.set_type(phi::CppTypeToDataType<T2>::Type());
tmp_tensor.Resize(tensor->dims());
auto* tmp_data = tmp_tensor.mutable_data<T2>(paddle::platform::CPUPlace());
auto* data = tensor->mutable_data<T1>(paddle::platform::CPUPlace());
for (int i = 0; i < tensor->numel(); i++) {
tmp_data[i] = static_cast<T2>(data[i]);
}
tensor->clear();
paddle::framework::TensorCopySync(
tmp_tensor, paddle::platform::CPUPlace(), tensor);
}
} // namespace
namespace paddle {
namespace framework {
namespace ir {
......@@ -450,7 +431,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
// conv_filter fp16 --> fp32
auto tensor_type = filter_t->dtype();
if (tensor_type == phi::DataType::FLOAT16) {
ConvertTensorType<float16, float>(filter_t);
CastToFp32(filter_t, nullptr);
}
auto filter_dims = filter_t->dims();
bool has_bias = with_bn || with_conv_bias;
......@@ -529,9 +510,6 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
}
}
}
if (tensor_type == phi::DataType::FLOAT16) {
ConvertTensorType<float, float16>(filter_t);
}
// filter max
Node* filter_int16 = nullptr;
Node* filter_max = nullptr;
......
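Note on the fuse-pass change above: the removed ConvertTensorType<T1, T2> helper round-tripped the filter through a per-element CPU copy (fp16 -> fp32 before the weight processing, then fp32 -> fp16 afterwards); the pass now calls CastToFp32 once and keeps the filter in fp32 for the "filter max" / int16 quantization step. A rough numpy sketch of that step, assuming per-tensor abs-max scaling (illustrative only, not the pass itself):

import numpy as np

def quantize_filter_to_int16(filter_fp16):
    # Cast fp16 weights up to fp32 first, as CastToFp32 does in the pass.
    w = filter_fp16.astype(np.float32)
    # Per-tensor abs-max serves as the quantization scale ("filter max").
    w_max = np.abs(w).max()
    # Map the fp32 weights onto the int16 range using that scale.
    w_int16 = np.clip(np.round(w / w_max * 32767.0), -32767, 32767).astype(np.int16)
    return w_int16, w_max

w = np.random.randn(8, 4, 3, 3).astype(np.float16)
w_q, w_max = quantize_filter_to_int16(w)
print(w_q.dtype, w_max)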
......@@ -164,7 +164,8 @@ XPUOpMap& get_kl2_ops() {
{"conv3d",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"conv2d_transpose_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"conv2d_transpose", XPUKernelSet({phi::DataType::FLOAT32})},
{"conv2d_transpose",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"cumsum",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
......
......@@ -67,5 +67,6 @@ PD_REGISTER_KERNEL(
#endif
#if defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(mean, XPU, ALL_LAYOUT, phi::MeanKernel, float) {}
PD_REGISTER_KERNEL(
mean, XPU, ALL_LAYOUT, phi::MeanKernel, float, phi::dtype::float16) {}
#endif
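With phi::dtype::float16 added to the XPU registration of phi::MeanKernel, paddle.mean can run directly on half-precision tensors on an XPU device. A minimal usage sketch, assuming an XPU build with a visible XPU device:

import paddle

paddle.set_device("xpu")
x = paddle.randn([4, 8]).astype("float16")  # fp16 input tensor
y = paddle.mean(x)                          # dispatches to the newly registered fp16 XPU kernel
print(y.dtype, float(y))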
......@@ -195,13 +195,6 @@ void PowKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& factor,
DenseTensor* out) {
// using XPUType = typename XPUTypeTrait<T>::Type;
// // dev_ctx.template Alloc<T>(out);
// auto pow_factor = factor.to<T>();
// const auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
// auto* y_data = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(out));
// // const T* x_data = x.data<T>();
// // T* y_data = out->data<T>();
dev_ctx.template Alloc<T>(out);
float pow_factor = factor.to<float>();
const T* x_data = x.data<T>();
......@@ -578,16 +571,11 @@ PD_REGISTER_KERNEL(
PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
// PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
// PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
// PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
// PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
// PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
......@@ -123,11 +123,11 @@ void Conv2dTransposeKernel(const Context& ctx,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2");
} else {
int r = xpu::conv2d_transpose_v2<float, float, float, int16_t>(
int r = xpu::conv2d_transpose_v2<XPUT, XPUT, XPUT, int16_t>(
ctx.x_context(),
x.data<float>(),
filter_.data<float>(),
out->data<float>(),
reinterpret_cast<const XPUT*>(x.data<T>()),
reinterpret_cast<const XPUT*>(filter.data<T>()),
reinterpret_cast<XPUT*>(out->data<T>()),
batch_size,
img_yc,
img_xh,
......@@ -148,5 +148,9 @@ void Conv2dTransposeKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(
conv2d_transpose, XPU, ALL_LAYOUT, phi::Conv2dTransposeKernel, float) {}
PD_REGISTER_KERNEL(conv2d_transpose,
XPU,
ALL_LAYOUT,
phi::Conv2dTransposeKernel,
float,
phi::dtype::float16) {}
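The kernel body is now templated on XPUT (the device-side type obtained from T via XPUTypeTrait), and the registration gains phi::dtype::float16, so transposed convolution accepts fp16 inputs and filters on XPU; conv2d_transpose_grad stays fp32-only in this change. A forward-only usage sketch, assuming an XPU device:

import paddle
import paddle.nn.functional as F

paddle.set_device("xpu")
x = paddle.randn([1, 4, 8, 8]).astype("float16")  # NCHW input
# Transposed-conv weights are laid out as [in_channels, out_channels, kH, kW].
w = paddle.randn([4, 6, 3, 3]).astype("float16")
y = F.conv2d_transpose(x, w, stride=2, padding=1)
print(y.shape, y.dtype)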
......@@ -154,7 +154,6 @@ class XPUTestBatchNormOp(XPUOpTestWrapper):
class TestBatchNormOp(unittest.TestCase):
def setUp(self):
self.op_type = "batch_norm"
self.dtype = np.float32
self.shape = [2, 3, 4, 5]
self.data_layout = "NCHW"
self.epsilon = 1e-05
......@@ -162,6 +161,9 @@ class XPUTestBatchNormOp(XPUOpTestWrapper):
self.init_dtype()
self.set_xpu()
self.set_attrs()
self.rtol = 1e-5
if self.dtype == np.float16:
self.rtol = 1e-2
if self.data_layout == "NHWC":
channel_size = self.shape[3]
......@@ -175,15 +177,15 @@ class XPUTestBatchNormOp(XPUOpTestWrapper):
np.random.seed(1024)
self.x_np = np.random.random_sample(self.shape).astype(self.dtype)
self.scale_np = np.random.random_sample([channel_size]).astype(
self.dtype
np.float32
)
self.bias_np = np.random.random_sample([channel_size]).astype(
self.dtype
np.float32
)
self.mean_np = np.zeros([channel_size]).astype(self.dtype)
self.variance_np = np.ones([channel_size]).astype(self.dtype)
self.saved_mean_np = np.zeros([channel_size]).astype(self.dtype)
self.saved_variance_np = np.ones([channel_size]).astype(self.dtype)
self.mean_np = np.zeros([channel_size]).astype(np.float32)
self.variance_np = np.ones([channel_size]).astype(np.float32)
self.saved_mean_np = np.zeros([channel_size]).astype(np.float32)
self.saved_variance_np = np.ones([channel_size]).astype(np.float32)
def set_attrs(self):
pass
......@@ -244,7 +246,110 @@ class XPUTestBatchNormOp(XPUOpTestWrapper):
self.epsilon,
self.data_layout,
)
np.testing.assert_allclose(y_np_ref, y_np, rtol=1e-05)
np.testing.assert_allclose(y_np_ref, y_np, rtol=self.rtol)
class TestBatchNormOpUseGlobalStats(unittest.TestCase):
def setUp(self):
self.places = [paddle.XPUPlace(0)]
self.init_test()
# train mode
def init_test(self):
self.use_global_stats = True
self.trainable_statistics = False
def test_global_stats(self):
for p in self.places:
with fluid.dygraph.guard(p):
x = paddle.randn([2, 6, 6, 4])
net1 = paddle.nn.BatchNorm(
6,
param_attr=fluid.ParamAttr(
initializer=paddle.nn.initializer.Constant(1.0)
),
use_global_stats=self.use_global_stats,
trainable_statistics=self.trainable_statistics,
)
net2 = paddle.nn.BatchNorm2D(
6, use_global_stats=self.use_global_stats
)
net2.weight = net1.weight
net2.bias = net1.bias
if self.trainable_statistics:
net1.training = False
net2.training = False
y1 = net1(x)
y2 = net2(x)
np.testing.assert_allclose(
y1.numpy(), y2.numpy(), rtol=1e-5
)
class TestBatchNormOpUseGlobalStats1(TestBatchNormOpUseGlobalStats):
# test mode
def init_test(self):
self.use_global_stats = True
self.trainable_statistics = True
class TestBatchNormUseGlobalStats2(TestBatchNormOpUseGlobalStats):
# train mode
def init_test(self):
self.use_global_stats = True
self.trainable_statistics = False
support_types = get_xpu_op_support_types('batch_norm')
for stype in support_types:
create_test_class(globals(), XPUTestBatchNormOp, stype)
class XPUTestBatchNormGradOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'batch_norm'
self.use_dynamic_create_class = False
class TestBatchNormGradOp(unittest.TestCase):
def setUp(self):
self.op_type = "batch_norm"
self.shape = [2, 3, 4, 5]
self.data_layout = "NCHW"
self.epsilon = 1e-05
self.momentum = 0.9
self.init_dtype()
self.set_xpu()
self.set_attrs()
if self.data_layout == "NHWC":
channel_size = self.shape[3]
elif self.data_layout == "NCHW":
channel_size = self.shape[1]
else:
raise ValueError(
"Unsupported data layout! Only NCHW and NHWC is supported, but received "
+ self.data_layout
)
np.random.seed(1024)
self.x_np = np.random.random_sample(self.shape).astype(self.dtype)
self.scale_np = np.random.random_sample([channel_size]).astype(
np.float32
)
self.bias_np = np.random.random_sample([channel_size]).astype(
np.float32
)
self.mean_np = np.zeros([channel_size]).astype(np.float32)
self.variance_np = np.ones([channel_size]).astype(np.float32)
self.saved_mean_np = np.zeros([channel_size]).astype(np.float32)
self.saved_variance_np = np.ones([channel_size]).astype(np.float32)
def set_attrs(self):
pass
def init_dtype(self):
self.dtype = self.in_type
def set_xpu(self):
self.__class__.use_xpu = True
self.__class__.op_type = self.in_type
self.place = paddle.XPUPlace(0)
def test_train(self):
y_grad_np = np.random.random_sample(self.shape).astype(self.dtype)
......@@ -349,58 +454,10 @@ class XPUTestBatchNormOp(XPUOpTestWrapper):
outputs[name], outs[id], rtol=1e-05, atol=1e-4
)
class TestBatchNormOpUseGlobalStats(unittest.TestCase):
def setUp(self):
self.places = [paddle.XPUPlace(0)]
self.init_test()
# train mode
def init_test(self):
self.use_global_stats = True
self.trainable_statistics = False
def test_global_stats(self):
for p in self.places:
with fluid.dygraph.guard(p):
x = paddle.randn([2, 6, 6, 4])
net1 = paddle.nn.BatchNorm(
6,
param_attr=fluid.ParamAttr(
initializer=paddle.nn.initializer.Constant(1.0)
),
use_global_stats=self.use_global_stats,
trainable_statistics=self.trainable_statistics,
)
net2 = paddle.nn.BatchNorm2D(
6, use_global_stats=self.use_global_stats
)
net2.weight = net1.weight
net2.bias = net1.bias
if self.trainable_statistics:
net1.training = False
net2.training = False
y1 = net1(x)
y2 = net2(x)
np.testing.assert_allclose(
y1.numpy(), y2.numpy(), rtol=1e-05
)
class TestBatchNormOpUseGlobalStats1(TestBatchNormOpUseGlobalStats):
# test mode
def init_test(self):
self.use_global_stats = True
self.trainable_statistics = True
class TestBatchNormUseGlobalStats2(TestBatchNormOpUseGlobalStats):
# train mode
def init_test(self):
self.use_global_stats = True
self.trainable_statistics = False
support_types = get_xpu_op_support_types('batch_norm')
for stype in support_types:
create_test_class(globals(), XPUTestBatchNormOp, stype)
support_types_grad = get_xpu_op_support_types('batch_norm_grad')
for stype_grad in support_types_grad:
create_test_class(globals(), XPUTestBatchNormGradOp, stype_grad)
if __name__ == '__main__':
unittest.main()
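For the fp16 cases, the updated test keeps scale/bias/mean/variance in fp32 while the input stays in the tested dtype, and relaxes rtol from 1e-5 to 1e-2. A simplified numpy reference for the inference-mode NCHW computation being checked (hypothetical helper, not the test's own reference function):

import numpy as np

def batch_norm_infer_nchw(x, scale, bias, mean, var, eps=1e-5):
    # Per-channel fp32 statistics are broadcast over the N, H, W axes.
    shape = (1, x.shape[1], 1, 1)
    x32 = x.astype(np.float32)
    y = (x32 - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    y = y * scale.reshape(shape) + bias.reshape(shape)
    return y.astype(x.dtype)  # cast back to the input dtype (fp16 here)

x = np.random.random_sample([2, 3, 4, 5]).astype(np.float16)
scale = np.random.random_sample([3]).astype(np.float32)
bias = np.random.random_sample([3]).astype(np.float32)
mean, var = np.zeros([3], np.float32), np.ones([3], np.float32)
y_ref = batch_norm_infer_nchw(x, scale, bias, mean, var)
# fp16 outputs are then compared with the relaxed tolerance (rtol=1e-2).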