Unverified · Commit 4c373e6b authored by Leo Guo, committed by GitHub

Fix bugs and add unit tests in instance_norm_grad_kernel when d_scale and d_bias are nullptr (#50394). Modify the code style of full_kernel.cc. Add new data types for the concat, elementwise_add, gather, scale, and scatter ops. test=kunlun
Parent 243cae59
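For illustration only (not part of the commit), a minimal sketch of what the new dtype registrations enable, assuming a PaddlePaddle build with XPU (Kunlun KL2) support and an available XPU device:

```python
import paddle

paddle.set_device("xpu")  # assumes an XPU build and device

# concat, elementwise_add, gather, scale (and scatter) now also accept int32 on KL2.
a = paddle.to_tensor([[1, 2], [3, 4]], dtype="int32")
b = paddle.to_tensor([[5, 6], [7, 8]], dtype="int32")

c = paddle.concat([a, b], axis=0)                           # int32 concat
s = a + b                                                   # int32 elementwise_add
g = paddle.gather(a, paddle.to_tensor([1], dtype="int64"))  # int32 gather
t = paddle.scale(a, scale=2.0)                              # int32 scale
```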
......@@ -118,7 +118,8 @@ XPUOpMap& get_kl2_ops() {
{"concat",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64})},
phi::DataType::INT64,
phi::DataType::INT32})},
{"conv2d_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"conv2d",
......@@ -159,7 +160,10 @@ XPUOpMap& get_kl2_ops() {
{"elementwise_add_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"elementwise_add",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64,
phi::DataType::INT32})},
{"elementwise_div_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"elementwise_div",
......@@ -300,7 +304,11 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64,
phi::DataType::FLOAT32})},
{"gather",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::BOOL})},
{"gaussian_random", XPUKernelSet({phi::DataType::FLOAT32})},
{"gelu_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......@@ -491,7 +499,8 @@ XPUOpMap& get_kl2_ops() {
{"scale",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT64})},
phi::DataType::INT64,
phi::DataType::INT32})},
{"scatter",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
......
......@@ -71,7 +71,7 @@ PD_REGISTER_KERNEL(full_sr,
phi::dtype::complex<double>) {}
#endif
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
#if defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(full_sr,
XPU,
ALL_LAYOUT,
......
......@@ -117,4 +117,5 @@ PD_REGISTER_KERNEL(concat,
phi::ConcatKernel,
float,
phi::dtype::float16,
int64_t) {}
int64_t,
int) {}
......@@ -75,5 +75,11 @@ PD_REGISTER_KERNEL(grad_add,
phi::GradAddXPUKernel,
phi::dtype::float16,
float) {}
PD_REGISTER_KERNEL(
add_raw, XPU, ALL_LAYOUT, phi::AddRawKernel, phi::dtype::float16, float) {}
PD_REGISTER_KERNEL(add_raw,
XPU,
ALL_LAYOUT,
phi::AddRawKernel,
phi::dtype::float16,
float,
int,
int64_t) {}
......@@ -28,32 +28,6 @@
namespace phi {
template <typename InType, typename OutType>
void TensorSetConstantXPU(phi::DenseTensor* tensor,
InType value,
phi::Place place) {
auto* begin = tensor->mutable_data<OutType>(place);
int64_t numel = tensor->numel();
std::unique_ptr<OutType[]> data_cpu(new OutType[numel]);
std::fill(
data_cpu.get(), data_cpu.get() + numel, static_cast<OutType>(value));
paddle::memory::Copy(place,
begin,
phi::CPUPlace(),
static_cast<void*>(data_cpu.get()),
numel * sizeof(OutType));
}
template <typename T, typename Context, typename VType>
void FullValueXPU(const Context& dev_ctx, DenseTensor* tensor, VType val) {
dev_ctx.template Alloc<T>(tensor);
PD_VISIT_ALL_TYPES(tensor->dtype(), "FullValueXPU", ([&] {
TensorSetConstantXPU<VType, data_t>(
tensor, val, dev_ctx.GetPlace());
}));
}
template <typename T, typename Context>
void FullKernel(const Context& dev_ctx,
const IntArray& shape,
......@@ -64,13 +38,12 @@ void FullKernel(const Context& dev_ctx,
out->Resize(phi::make_ddim(shape.GetData()));
int numel = out->numel();
dev_ctx.template Alloc<T>(out);
auto value = val.to<double>();
auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
if (numel > 0) {
int r = xpu::constant(dev_ctx.x_context(),
out_data,
out->numel(),
static_cast<XPUInTDType>(value));
static_cast<XPUInTDType>(val.to<T>()));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
}
}
......
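A possible motivation for switching the fill value from `val.to<double>()` to `val.to<T>()` is that large 64-bit integers are not exactly representable as a double; a sketch of the effect, assuming `full` is registered for int64 on XPU:

```python
import paddle

paddle.set_device("xpu")  # assumes an XPU build and device

# 2**53 + 1 cannot be represented exactly as a double, so routing the fill value
# through double (the old code path) could silently change large int64 fills.
t = paddle.full([2, 2], 2**53 + 1, dtype="int64")
print(t)
```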
......@@ -83,5 +83,12 @@ void GatherKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
gather, XPU, ALL_LAYOUT, phi::GatherKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(gather,
XPU,
ALL_LAYOUT,
phi::GatherKernel,
float,
phi::dtype::float16,
int,
int64_t,
bool) {}
......@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/instance_norm_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"
namespace phi {
......@@ -24,46 +25,89 @@ void InstanceNormGradKernel(const Context& dev_ctx,
const paddle::optional<DenseTensor>& scale,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const DenseTensor& y_grad,
const DenseTensor& d_y,
float epsilon,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
DenseTensor* d_x,
DenseTensor* d_scale,
DenseTensor* d_bias) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto& x_dims = x.dims();
int n = x_dims[0];
int c = x_dims[1];
int h = x_dims[2];
int w = x_dims[3];
int N, C, H, W, D;
funcs::ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
PADDLE_ENFORCE_EQ(
x_dims.size() <= 5 && D == 1,
true,
phi::errors::InvalidArgument(
"The size of input's dimensions should be less equal than 5",
"and the dimension of D should be eaual to 1",
"But received: the size of input's dimensions is [%d]",
x_dims.size()));
dev_ctx.template Alloc<T>(x_grad);
if (bias_grad != nullptr) {
dev_ctx.template Alloc<float>(bias_grad);
}
if (scale_grad != nullptr) {
dev_ctx.template Alloc<float>(scale_grad);
dev_ctx.template Alloc<T>(d_x);
T* d_scale_data = nullptr;
T* d_bias_data = nullptr;
if (d_scale && d_bias) {
dev_ctx.template Alloc<float>(d_scale);
dev_ctx.template Alloc<float>(d_bias);
d_scale_data = d_scale->data<float>();
d_bias_data = d_bias->data<float>();
}
const auto scale_ptr = scale.get_ptr();
if (scale_ptr) {
PADDLE_ENFORCE_EQ(
scale_ptr->dims().size(),
1UL,
phi::errors::InvalidArgument(
"The `shape` in InstanceNormOp is invalid: "
"the size of scale's dimensions must be equal to 1. But "
"received: the size of scale's dimensions"
"is [%d]",
scale_ptr->dims().size()));
PADDLE_ENFORCE_EQ(scale_ptr->dims()[0],
C,
phi::errors::InvalidArgument(
"The `shape` in InstanceNormOp is invalid: "
"the first dimension of scale must be equal to "
"Channels([%d]). But received: "
"the first dimension of scale is [%d],"
"the dimensions of scale is [%s], ",
C,
scale_ptr->dims()[0],
scale_ptr->dims()));
}
int r = xpu::instance_norm_grad(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(y_grad.data<T>()),
reinterpret_cast<XPUType*>(x_grad->data<T>()),
scale_ptr->data<float>(),
saved_mean.data<float>(),
saved_variance.data<float>(),
scale_grad->data<float>(),
bias_grad->data<float>(),
n,
c,
h,
w,
epsilon,
true);
DenseTensor scale_tmp;
int r;
if (!scale_ptr) {
scale_tmp.Resize({C});
dev_ctx.template Alloc<T>(&scale_tmp);
r = xpu::constant(dev_ctx.x_context(),
reinterpret_cast<XPUType*>(scale_tmp.data<T>()),
scale_tmp.numel(),
static_cast<XPUType>(1));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
}
auto scale_ptr_tmp = scale_ptr ? scale_ptr : &scale_tmp;
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
auto d_x_data =
d_x ? d_x->data<T>() : RAII_GUARD.alloc_l3_or_gm<T>(x.numel());
r = xpu::instance_norm_grad(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(d_y.data<T>()),
reinterpret_cast<XPUType*>(d_x_data),
scale_ptr_tmp->data<float>(),
saved_mean.data<float>(),
saved_variance.data<float>(),
d_scale_data,
d_bias_data,
N,
C,
H,
W,
epsilon,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "instance_norm_grad");
}
......
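The nullptr case that this kernel fix (and the new no_grad_set tests below) covers corresponds to requesting only d_x; a hedged dygraph sketch, assuming an XPU device is available:

```python
import paddle
import paddle.nn.functional as F

paddle.set_device("xpu")  # assumes an XPU build and device

x = paddle.randn([2, 3, 4, 5])
x.stop_gradient = False

# With weight=None and bias=None there is nothing to differentiate w.r.t. scale/bias,
# so the XPU grad kernel receives null d_scale/d_bias and must still produce d_x.
y = F.instance_norm(x)
y.sum().backward()
print(x.grad.shape)  # [2, 3, 4, 5]
```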
......@@ -58,4 +58,5 @@ PD_REGISTER_KERNEL(scale,
phi::ScaleKernel,
float,
phi::dtype::float16,
int,
int64_t) {}
......@@ -114,4 +114,4 @@ void ScatterKernel(const Context &ctx,
} // namespace phi
PD_REGISTER_KERNEL(
scatter, XPU, ALL_LAYOUT, phi::ScatterKernel, float, int64_t) {}
scatter, XPU, ALL_LAYOUT, phi::ScatterKernel, float, int, int64_t) {}
......@@ -195,7 +195,7 @@ class TestNewCustomOpXpuSetUpInstall(unittest.TestCase):
self.custom_op = custom_relu_xpu_module_setup.custom_relu
self.dtypes = ['float32', 'float64']
self.dtypes = ['float32']
self.device = 'xpu'
# config seed
......
......@@ -18,6 +18,8 @@ import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
sys.path.append("..")
from op_test_xpu import XPUOpTest
......@@ -69,6 +71,7 @@ class XPUTestInstanceNormOp(XPUOpTestWrapper):
self.dtype = self.in_type
self.shape = [2, 3, 4, 5]
self.epsilon = 1e-05
self.no_grad_set = None
self.set_attrs()
np.random.seed(12345)
......@@ -112,7 +115,12 @@ class XPUTestInstanceNormOp(XPUOpTestWrapper):
self.check_output_with_place(paddle.XPUPlace(0))
def test_check_grad(self):
self.check_grad_with_place(paddle.XPUPlace(0), ['X'], 'Y')
self.check_grad_with_place(
paddle.XPUPlace(0),
['X', 'Scale', 'Bias'],
['Y'],
self.no_grad_set,
)
class TestXPUInstanceNormOp1(XPUTestInstanceNormOp):
def set_attrs(self):
......@@ -134,6 +142,57 @@ class XPUTestInstanceNormOp(XPUOpTestWrapper):
def set_attrs(self):
self.shape = [10, 3, 512, 1]
class TestXPUInstanceNormOp6(XPUTestInstanceNormOp):
def set_attrs(self):
self.shape = [10, 12, 32, 32]
self.no_grad_set = set(['Scale', 'Bias'])
class TestXPUInstanceNormOp7(XPUTestInstanceNormOp):
def set_attrs(self):
self.shape = [4, 5, 6, 7]
self.no_grad_set = set(['Scale', 'Bias'])
class TestXPUInstanceNormOp8(XPUTestInstanceNormOp):
def set_attrs(self):
self.shape = [1, 8, 16, 16]
self.no_grad_set = set(['Scale', 'Bias'])
class TestXPUInstanceNormOp9(XPUTestInstanceNormOp):
def set_attrs(self):
self.shape = [4, 16, 256, 128]
self.no_grad_set = set(['Scale', 'Bias'])
class TestXPUInstanceNormOp10(XPUTestInstanceNormOp):
def set_attrs(self):
self.shape = [10, 3, 512, 1]
self.no_grad_set = set(['Scale', 'Bias'])
class TestInstanceNormOpError(XPUOpTest):
def setUp(self):
self.__class__.op_type = "instance_norm"
self.__class__.no_need_check_grad = True
self.dtype = self.in_type
def test_errors(self):
with program_guard(Program(), Program()):
# the input of instance_norm must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.XPUPlace(0)
)
self.assertRaises(TypeError, paddle.static.nn.instance_norm, x1)
# the input dtype of instance_norm must be float32
x2 = paddle.static.data(
name='x2', shape=[-1, 3, 4, 5, 6], dtype="int32"
)
self.assertRaises(TypeError, paddle.static.nn.instance_norm, x2)
# the rank of the input for instance_norm must be between 2 and 5
x3 = paddle.static.data(name='x', shape=[3], dtype="float32")
self.assertRaises(
ValueError, paddle.static.nn.instance_norm, x3
)
support_types = get_xpu_op_support_types('instance_norm')
for stype in support_types:
......