Unverified commit 216d25ac, authored by ykkk2333, committed by GitHub

add instance norm op for xpu (#45097)

* xpu unittest grad compute supports more types, *test=kunlun

* add instance norm xpu, *test=kunlun
Parent f4da2d4d
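A minimal usage sketch of the op this commit enables, assuming a PaddlePaddle build with XPU support and an available `xpu:0` device; the dygraph API calls and device index below are illustrative and not part of the diff:

```python
import numpy as np
import paddle

# Sketch only: falls back to CPU if the build has no XPU support.
paddle.set_device('xpu:0' if paddle.is_compiled_with_xpu() else 'cpu')

x = paddle.to_tensor(np.random.rand(2, 3, 4, 5).astype('float32'))  # NCHW input
x.stop_gradient = False
norm = paddle.nn.InstanceNorm2D(num_features=3, epsilon=1e-5)

y = norm(x)          # forward pass dispatches to the new instance_norm XPU kernel
y.sum().backward()   # backward pass dispatches to instance_norm_grad
print(y.shape, x.grad.shape)
```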
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/phi/kernels/instance_norm_grad_kernel.h"
#include "paddle/phi/kernels/instance_norm_kernel.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class InstanceNormXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto epsilon = ctx.Attr<float>("epsilon");
const auto* x = ctx.Input<Tensor>("X");
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* bias = ctx.Input<Tensor>("Bias");
auto* y = ctx.Output<Tensor>("Y");
auto* mean = ctx.Output<Tensor>("SavedMean");
auto* variance = ctx.Output<Tensor>("SavedVariance");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// call phi kernel
phi::InstanceNormKernel<T>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*x,
*scale,
*bias,
epsilon,
y,
mean,
variance);
}
};
template <typename DeviceContext, typename T>
class InstanceNormGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto epsilon = ctx.Attr<float>("epsilon");
const auto* x = ctx.Input<Tensor>("X");
const auto* mean = ctx.Input<Tensor>("SavedMean");
const auto* variance = ctx.Input<Tensor>("SavedVariance");
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// call phi kernel
phi::InstanceNormGradKernel<T>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*x,
*dy,
*scale,
*mean,
*variance,
epsilon,
dx,
dbias,
dscale);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
instance_norm,
ops::InstanceNormXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
instance_norm_grad,
ops::InstanceNormGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif  // PADDLE_WITH_XPU
@@ -289,6 +289,10 @@ XPUOpMap& get_kl2_ops() {
{"huber_loss", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"iou_similarity",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"instance_norm",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"instance_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"label_smooth",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lars_momentum",
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/instance_norm_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void InstanceNormGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
const DenseTensor& y_grad,
float epsilon,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto& x_dims = x.dims();
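  // Assumes a 4-D NCHW input: (batch, channels, height, width).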
int n = x_dims[0];
int c = x_dims[1];
int h = x_dims[2];
int w = x_dims[3];
dev_ctx.template Alloc<T>(x_grad);
if (bias_grad != nullptr) {
dev_ctx.template Alloc<float>(bias_grad);
}
if (scale_grad != nullptr) {
dev_ctx.template Alloc<float>(scale_grad);
}
const auto scale_ptr = scale.get_ptr();
int r = xpu::instance_norm_grad(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(y_grad.data<T>()),
reinterpret_cast<XPUType*>(x_grad->data<T>()),
scale_ptr->data<float>(),
saved_mean.data<float>(),
saved_variance.data<float>(),
scale_grad->data<float>(),
bias_grad->data<float>(),
n,
c,
h,
w,
epsilon,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "instance_norm_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(
instance_norm_grad, XPU, ALL_LAYOUT, phi::InstanceNormGradKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/instance_norm_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void InstanceNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias,
float epsilon,
DenseTensor* y,
DenseTensor* saved_mean,
DenseTensor* saved_var) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto& x_dims = x.dims();
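  // Assumes a 4-D NCHW input; mean/variance are computed per (n, c) over H*W.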
int n = x_dims[0];
int c = x_dims[1];
int h = x_dims[2];
int w = x_dims[3];
dev_ctx.template Alloc<T>(y);
dev_ctx.template Alloc<float>(saved_mean);
dev_ctx.template Alloc<float>(saved_var);
const auto scale_ptr = scale.get_ptr();
const auto bias_ptr = bias.get_ptr();
int r = xpu::instance_norm(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(y->data<T>()),
n,
c,
h,
w,
epsilon,
scale_ptr->data<float>(),
bias_ptr->data<float>(),
saved_mean->data<float>(),
saved_var->data<float>(),
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "instance_norm");
}
} // namespace phi
PD_REGISTER_KERNEL(
instance_norm, XPU, ALL_LAYOUT, phi::InstanceNormKernel, float) {}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import sys
import unittest
from functools import reduce
sys.path.append("..")
from op_test import OpTest
from op_test_xpu import XPUOpTest
from operator import mul
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
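# Reference implementation: normalize each (n, c) slice of x with the given
# per-instance mean/var, then apply per-channel scale and bias.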
def _reference_instance_norm_naive(x, scale, bias, epsilon, mean, var):
x_shape = x.shape
if len(x_shape) == 2:
x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
n, c, h, w = x.shape
mean_tile = np.reshape(mean, (n, c, 1, 1))
mean_tile = np.tile(mean_tile, (1, 1, h, w))
var_tile = np.reshape(var, (n, c, 1, 1))
var_tile = np.tile(var_tile, (1, 1, h, w))
x_norm = (x - mean_tile) / np.sqrt(var_tile + epsilon).astype('float32')
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
bias_tile = np.reshape(bias, (1, c, 1, 1))
bias_tile = np.tile(bias_tile, (n, 1, h, w))
y = scale_tile * x_norm + bias_tile
if len(x_shape) == 2:
y = np.reshape(y, x_shape)
return y, mean, var
def _cal_mean_variance(x, epsilon, mean_shape):
mean = np.reshape(np.mean(x, axis=(2, 3)), mean_shape)
var = np.reshape(np.var(x, axis=(2, 3)), mean_shape)
return mean, var
class XPUTestInstanceNormOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'instance_norm'
self.use_dynamic_create_class = False
class XPUTestInstanceNormOp(XPUOpTest):
def setUp(self):
self.op_type = "instance_norm"
self.dtype = self.in_type
self.shape = [2, 3, 4, 5]
self.epsilon = 1e-05
self.set_attrs()
np.random.seed(12345)
epsilon = self.epsilon
shape = self.shape
n, c, h, w = shape[0], shape[1], shape[2], shape[3]
scale_shape = [c]
mean_shape = [n * c]
x_np = np.random.random_sample(shape).astype(self.dtype)
scale_np = np.random.random_sample(scale_shape).astype(np.float32)
bias_np = np.random.random_sample(scale_shape).astype(np.float32)
mean, variance = self.set_global_mean_var(mean_shape, x_np)
ref_y_np, ref_saved_mean, variance_tmp = _reference_instance_norm_naive(
x_np, scale_np, bias_np, epsilon, mean, variance)
ref_saved_variance = 1 / np.sqrt(variance_tmp + epsilon)
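            # The SavedVariance output stores the inverse standard deviation, hence the transform above.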
self.inputs = {'X': x_np, 'Scale': scale_np, 'Bias': bias_np}
self.outputs = {
'Y': ref_y_np,
'SavedMean': ref_saved_mean,
'SavedVariance': ref_saved_variance
}
self.attrs = {'epsilon': epsilon, 'use_xpu': True}
def set_global_mean_var(self, mean_shape, x):
mean, variance = _cal_mean_variance(x, self.epsilon, mean_shape)
return mean, variance
def set_attrs(self):
pass
def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0))
def test_check_grad(self):
self.check_grad_with_place(paddle.XPUPlace(0), ['X'], 'Y')
class TestXPUInstanceNormOp1(XPUTestInstanceNormOp):
def set_attrs(self):
self.shape = [10, 12, 32, 32]
class TestXPUInstanceNormOp2(XPUTestInstanceNormOp):
def set_attrs(self):
self.shape = [4, 5, 6, 7]
class TestXPUInstanceNormOp3(XPUTestInstanceNormOp):
def set_attrs(self):
self.shape = [1, 8, 16, 16]
class TestXPUInstanceNormOp4(XPUTestInstanceNormOp):
def set_attrs(self):
self.shape = [4, 16, 256, 128]
class TestXPUInstanceNormOp5(XPUTestInstanceNormOp):
def set_attrs(self):
self.shape = [10, 3, 512, 1]
support_types = get_xpu_op_support_types('instance_norm')
for stype in support_types:
create_test_class(globals(), XPUTestInstanceNormOp, stype)
if __name__ == "__main__":
unittest.main()