Unverified commit c86a5140, authored by ronnywang, committed by GitHub

[XPU] add group_norm, sin, cos, linspace, randint kernels (#50465)

* [XPU] add group_norm kernel

* update

* add xpu sin, cos, randint, linspace kernels

* update

* update
Parent: 8910bb4a
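The new kernels back the corresponding Paddle Python ops. A minimal sketch of exercising them on an XPU device (assuming a PaddlePaddle build with XPU support, so that 'xpu' is a valid device string):

import paddle

paddle.device.set_device('xpu')

x = paddle.uniform([4, 8], min=-3.14, max=3.14, dtype='float32')
sin_out = paddle.sin(x)                                # sin kernel
cos_out = paddle.cos(x)                                # cos kernel
line = paddle.linspace(0, 10, 11, dtype='float32')     # linspace kernel
rand = paddle.randint(low=-10, high=10, shape=[2, 3])  # randint kernel

# group_norm kernel (float32 only on XPU, per the op map below)
gn = paddle.nn.GroupNorm(num_groups=2, num_channels=8)
y = gn(paddle.uniform([2, 8, 3, 5], dtype='float32'))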
@@ -711,6 +711,12 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32})},
{"sin", XPUKernelSet({phi::DataType::FLOAT32})},
{"cos", XPUKernelSet({phi::DataType::FLOAT32})},
{"linspace",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
{"randint", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"group_norm", XPUKernelSet({phi::DataType::FLOAT32})},
// AddMore
{"sequence_conv", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -457,6 +457,32 @@ struct XPUFloorFunctor : public funcs::BaseActivationFunctor<T> {
}
};
template <typename T>
struct XPUSinFunctor : public funcs::BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
template <typename Context>
void operator()(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) const {
int r = xpu_activation_func<Context, T, XPUType>(
dev_ctx, x, out, xpu::sin<XPUType>);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "sin");
}
};
template <typename T>
struct XPUCosFunctor : public funcs::BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
template <typename Context>
void operator()(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) const {
int r = xpu_activation_func<Context, T, XPUType>(
dev_ctx, x, out, xpu::cos<XPUType>);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cos");
}
};
DEFINE_XPU_ACTIVATION_KERNEL(Exp, XPUExpFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Floor, XPUFloorFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Log, XPULogFunctor)
@@ -467,6 +493,8 @@ DEFINE_XPU_ACTIVATION_KERNEL(Square, XPUSquareFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Sqrt, XPUSqrtFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Tanh, XPUTanhFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Silu, XPUSiluFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Sin, XPUSinFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Cos, XPUCosFunctor)
DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, XPUMishFunctor, threshold)
DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu,
@@ -531,3 +559,5 @@ PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/group_norm_kernel.h"
#include <algorithm>
#include <array>
#include <numeric>
#include <string>
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void GroupNormKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias,
float epsilon,
int groups,
const std::string& data_layout_str,
DenseTensor* y,
DenseTensor* mean,
DenseTensor* var) {
using XPUType = typename XPUTypeTrait<T>::Type;
const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
const auto scale_ptr = scale.get_ptr();
const auto bias_ptr = bias.get_ptr();
const auto x_dims = phi::vectorize<int>(x.dims());
const int N = x_dims[0];
const bool channel_first =
data_layout == DataLayout::kNCHW || data_layout == DataLayout::kNCDHW;
const int C = (channel_first ? x_dims[1] : x_dims[x_dims.size() - 1]);
const int L =
(channel_first
? std::accumulate(
x_dims.begin() + 2, x_dims.end(), 1, std::multiplies<int>())
: std::accumulate(x_dims.begin() + 1,
x_dims.end() - 1,
1,
std::multiplies<int>()));
dev_ctx.template Alloc<T>(y);
dev_ctx.template Alloc<T>(mean);
dev_ctx.template Alloc<T>(var);
auto* x_data = x.data<T>();
auto* y_data = y->data<T>();
auto* mean_data = mean->data<T>();
auto* var_data = var->data<T>();
const T* scale_data = nullptr;
if (scale_ptr) scale_data = scale_ptr->data<T>();
const T* bias_data = nullptr;
if (bias_ptr) bias_data = bias_ptr->data<T>();
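  // The spatial dimensions were flattened into L above; the tensor is handed
  // to XDNN as N x C x L x 1, with channel_first selecting NCHW vs. NHWC.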
auto r =
xpu::group_norm<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<XPUType*>(y_data),
N,
C,
L,
1,
groups,
static_cast<XPUType>(epsilon),
reinterpret_cast<const XPUType*>(scale_data),
reinterpret_cast<const XPUType*>(bias_data),
reinterpret_cast<XPUType*>(mean_data),
reinterpret_cast<XPUType*>(var_data),
channel_first);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "group_norm");
}
} // namespace phi
PD_REGISTER_KERNEL(group_norm, XPU, ALL_LAYOUT, phi::GroupNormKernel, float) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/linspace_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
namespace phi {
template <typename T, typename Context>
T GetValueOfExpectedType(const Context& ctx, const DenseTensor& x) {
switch (x.dtype()) {
case DataType::FLOAT32:
return static_cast<T>(GetValue<float, Context>(ctx, x));
case DataType::FLOAT64:
return static_cast<T>(GetValue<double, Context>(ctx, x));
case DataType::INT32:
return static_cast<T>(GetValue<int32_t, Context>(ctx, x));
case DataType::INT64:
return static_cast<T>(GetValue<int64_t, Context>(ctx, x));
case DataType::FLOAT16:
return static_cast<T>(GetValue<phi::dtype::float16, Context>(ctx, x));
case DataType::BFLOAT16:
return static_cast<T>(GetValue<phi::dtype::bfloat16, Context>(ctx, x));
case DataType::BOOL:
return static_cast<T>(GetValue<bool, Context>(ctx, x));
case DataType::INT16:
return static_cast<T>(GetValue<int16_t, Context>(ctx, x));
case DataType::UINT8:
return static_cast<T>(GetValue<uint8_t, Context>(ctx, x));
default:
PADDLE_THROW(phi::errors::Unimplemented(
"Data type (%s) is not supported when casting data type.",
x.dtype()));
}
}
template <typename T, typename Context>
void LinspaceKernel(const Context& ctx,
const DenseTensor& start,
const DenseTensor& stop,
const DenseTensor& number,
DataType dtype,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
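  // start, stop and num are passed in as tensors; GetValueOfExpectedType
  // reads each one back as a scalar of the requested type.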
T start_value = GetValueOfExpectedType<T, Context>(ctx, start);
T stop_value = GetValueOfExpectedType<T, Context>(ctx, stop);
int32_t num = GetValueOfExpectedType<int32_t, Context>(ctx, number);
PADDLE_ENFORCE_GT(
num,
0,
phi::errors::InvalidArgument("The num of linspace op should be larger "
"than 0, but received num is %d",
num));
out->Resize(phi::make_ddim({num}));
T* out_data = ctx.template Alloc<T>(out);
int r = xpu::linspace(ctx.x_context(),
reinterpret_cast<XPUType*>(out_data),
static_cast<XPUType>(start_value),
static_cast<XPUType>(stop_value),
num);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "linspace");
}
} // namespace phi
PD_REGISTER_KERNEL(
linspace, XPU, ALL_LAYOUT, phi::LinspaceKernel, float, int32_t) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/randint_kernel.h"
#include <random>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void RandintRawKernel(const Context& dev_ctx,
int low,
int high,
const IntArray& shape,
DataType dtype,
int seed,
DenseTensor* out) {
  out->Resize(phi::make_ddim(shape.GetData()));
  T* data = dev_ctx.template Alloc<T>(out);
  int64_t numel = out->numel();
  std::shared_ptr<std::mt19937_64> engine;
  if (seed) {
    engine = std::make_shared<std::mt19937_64>();
    engine->seed(seed);
  } else {
    engine = dev_ctx.GetGenerator()->GetCPUEngine();
  }
  // Draw the integers on the host in [low, high), then copy them to the XPU.
  std::unique_ptr<T[]> data_cpu(new T[numel]);
  std::uniform_int_distribution<T> dist(low, high - 1);
  for (int64_t i = 0; i < numel; ++i) {
    data_cpu[i] = dist(*engine);
  }
  paddle::memory::Copy(dev_ctx.GetPlace(),
                       data,
                       phi::CPUPlace(),
                       reinterpret_cast<void*>(data_cpu.get()),
                       numel * sizeof(T));
}
template <typename T, typename Context>
void RandintKernel(const Context& dev_ctx,
int low,
int high,
const IntArray& shape,
DataType dtype,
DenseTensor* out) {
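  // A seed of 0 tells RandintRawKernel to use the generator owned by the
  // device context instead of a fixed seed.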
RandintRawKernel<T>(dev_ctx, low, high, shape, dtype, 0, out);
}
} // namespace phi
PD_REGISTER_KERNEL(
randint_raw, XPU, ALL_LAYOUT, phi::RandintRawKernel, int, int64_t) {}
PD_REGISTER_KERNEL(randint, XPU, ALL_LAYOUT, phi::RandintKernel, int, int64_t) {
}
@@ -1222,5 +1222,104 @@ def ref_mish(x, threshold=20):
return out
class XPUTestSinOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'sin'
self.use_dynamic_create_class = False
class XPUTestSinBase(TestActivationOPBase):
def set_case(self):
self.op_type = "sin"
self.dtype = self.in_type
self.init_config()
out = np.sin(self.x)
self.inputs = {'X': self.x}
self.outputs = {'Out': out}
self.attrs = {'use_xpu': True}
def init_config(self):
self.x = np.random.uniform(-np.pi, np.pi, [11, 17]).astype(
self.dtype
)
class XPUTestSin_ZeroDim(XPUTestSinBase):
def init_config(self):
self.x = np.random.uniform(-np.pi, np.pi, []).astype(self.dtype)
class XPUTestSin2(XPUTestSinBase):
def init_config(self):
self.x = np.random.uniform(-np.pi, np.pi, [1024, 8]).astype(
self.dtype
)
class XPUTestSin3(XPUTestSinBase):
def init_config(self):
self.x = np.random.uniform(-np.pi, np.pi, [4, 512, 15, 15]).astype(
self.dtype
)
class XPUTestSin4(XPUTestSinBase):
def init_config(self):
self.x = np.random.uniform(-np.pi, np.pi, [4, 256, 22, 22]).astype(
self.dtype
)
support_types = get_xpu_op_support_types('sin')
for stype in support_types:
create_test_class(globals(), XPUTestSinOP, stype)
class XPUTestCosOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'cos'
self.use_dynamic_create_class = False
class XPUTestCosBase(TestActivationOPBase):
def set_case(self):
self.op_type = "cos"
self.dtype = self.in_type
self.init_config()
out = np.cos(self.x)
self.inputs = {'X': self.x}
self.outputs = {'Out': out}
self.attrs = {'use_xpu': True}
def init_config(self):
self.x = np.random.uniform(-np.pi, np.pi, [11, 17]).astype(
self.dtype
)
class XPUTestCos_ZeroDim(XPUTestCosBase):
def init_config(self):
self.x = np.random.uniform(-np.pi, np.pi, []).astype(self.dtype)
class XPUTestCos2(XPUTestCosBase):
def init_config(self):
self.x = np.random.uniform(-np.pi, np.pi, [1024, 8]).astype(
self.dtype
)
class XPUTestCos3(XPUTestCosBase):
def init_config(self):
self.x = np.random.uniform(-np.pi, np.pi, [4, 512, 15, 15]).astype(
self.dtype
)
class XPUTestCos4(XPUTestCosBase):
def init_config(self):
self.x = np.random.uniform(-np.pi, np.pi, [4, 256, 22, 22]).astype(
self.dtype
)
support_types = get_xpu_op_support_types('cos')
for stype in support_types:
create_test_class(globals(), XPUTestCosOP, stype)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
sys.path.append("..")
from op_test import OpTest
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
import paddle
paddle.enable_static()
def group_norm_naive(x, scale, bias, epsilon, groups, data_layout):
if data_layout == "NHWC":
x = np.transpose(x, (0, 3, 1, 2)) # NHWC => NCHW
N, C, H, W = x.shape
G = groups
x = x.reshape((N * G, -1))
mean = np.mean(x, axis=1, keepdims=True)
var = np.var(x, axis=1, keepdims=True)
output = (x - mean) / np.sqrt(var + epsilon)
output = output.reshape((N, C, H, W)) * scale.reshape(
(-1, 1, 1)
) + bias.reshape((-1, 1, 1))
if data_layout == "NHWC":
output = np.transpose(output, (0, 2, 3, 1)) # NCHW => NHWC
return output, mean.reshape((N, G)), var.reshape((N, G))
class XPUTestGroupNormOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'group_norm'
self.use_dynamic_create_class = False
class TestGroupNormOp(XPUOpTest):
def init_test_case(self):
self.data_format = "NCHW"
self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NCHW"}
def setUp(self):
'''Test GroupNorm Op with supplied attributes'''
self.__class__.op_type = 'group_norm'
self.dtype = self.in_type
self.shape = (2, 100, 3, 5)
self.init_test_case()
input = np.random.random(self.shape).astype(self.dtype)
if self.data_format == "NHWC":
input = np.transpose(input, (0, 2, 3, 1))
scale = np.random.random([self.shape[1]]).astype(self.dtype)
bias = np.random.random([self.shape[1]]).astype(self.dtype)
output, mean, var = group_norm_naive(
input,
scale,
bias,
self.attrs['epsilon'],
self.attrs['groups'],
self.data_format,
)
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(input),
'Scale': OpTest.np_dtype_to_fluid_dtype(scale),
'Bias': OpTest.np_dtype_to_fluid_dtype(bias),
}
self.outputs = {'Y': output, 'Mean': mean, 'Variance': var}
self.attrs['data_layout'] = self.data_format
def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0))
def test_check_grad(self):
pass
class TestGroupNormOp2(TestGroupNormOp):
def init_test_case(self):
self.data_format = "NHWC"
self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NHWC"}
support_types = get_xpu_op_support_types('group_norm')
for stype in support_types:
create_test_class(globals(), XPUTestGroupNormOp, stype)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
sys.path.append("..")
from op_test_xpu import XPUOpTest, convert_np_dtype_to_dtype_
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
import paddle
paddle.enable_static()
class XPUTestLinspaceOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'linspace'
self.use_dynamic_create_class = False
    class TestXPULinspaceOp(XPUOpTest):
def setUp(self):
self.op_type = "linspace"
self.dtype = self.in_type
self.set_attrs()
self.atol = 1e-4
np.random.seed(10)
self.inputs = {
'Start': np.array([0]).astype(self.dtype),
'Stop': np.array([10]).astype(self.dtype),
'Num': np.array([11]).astype('int32'),
}
self.outputs = {'Out': np.arange(0, 11).astype(self.dtype)}
self.attrs = {'dtype': int(convert_np_dtype_to_dtype_(self.dtype))}
def set_attrs(self):
pass
def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0), atol=self.atol)
    class TestXPULinspace2(TestXPULinspaceOp):
def set_attrs(self):
self.inputs = {
'Start': np.array([10]).astype(self.dtype),
'Stop': np.array([0]).astype(self.dtype),
'Num': np.array([11]).astype('int32'),
}
self.outputs = {'Out': np.arange(10, -1, -1).astype(self.dtype)}
    class TestXPULinspace3(TestXPULinspaceOp):
def set_attrs(self):
self.inputs = {
'Start': np.array([10]).astype(self.dtype),
'Stop': np.array([0]).astype(self.dtype),
'Num': np.array([1]).astype('int32'),
}
self.outputs = {'Out': np.array(10, dtype=self.dtype)}
support_types = get_xpu_op_support_types('linspace')
for stype in support_types:
create_test_class(globals(), XPUTestLinspaceOp, stype)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
import paddle
paddle.enable_static()
def output_hist(out):
hist, _ = np.histogram(out, range=(-10, 10))
hist = hist.astype("float32")
hist /= float(out.size)
prob = 0.1 * np.ones((10))
return hist, prob
class XPUTestRandIntOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'randint'
self.use_dynamic_create_class = False
class TestXPURandIntOp(XPUOpTest):
def setUp(self):
self.op_type = "randint"
self.dtype = self.in_type
self.set_attrs()
self.atol = 1e-4
np.random.seed(10)
self.inputs = {}
self.outputs = {"Out": np.zeros((10000, 784)).astype("float32")}
self.attrs = {
"shape": [10000, 784],
"low": -10,
"high": 10,
"seed": 10,
}
self.output_hist = output_hist
def set_attrs(self):
pass
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
hist, prob = self.output_hist(np.array(outs[0]))
np.testing.assert_allclose(hist, prob, rtol=0, atol=0.001)
support_types = get_xpu_op_support_types('randint')
for stype in support_types:
create_test_class(globals(), XPUTestRandIntOp, stype)
if __name__ == "__main__":
unittest.main()