Unverified commit fb16bdc7, authored by ykkk2333, committed by GitHub

add xpu cumprod, group norm grad (#52089)

Parent commit: 93d01787
@@ -8,7 +8,7 @@ set(XPU_API_LIB_NAME "libxpuapi.so")
 set(XPU_RT_LIB_NAME "libxpurt.so")
 set(XPU_XFT_LIB_NAME "libxft.so")
-set(XPU_BASE_DATE "20230310")
+set(XPU_BASE_DATE "20230323")
 set(XPU_XCCL_BASE_VERSION "1.0.13")
 set(XPU_XFT_BASE_VERSION "latest")
@@ -160,6 +160,10 @@ XPUOpMap& get_kl2_ops() {
                     phi::DataType::FLOAT16,
                     phi::DataType::INT32,
                     phi::DataType::INT64})},
+    {"cumprod",
+     XPUKernelSet({phi::DataType::FLOAT32,
+                   phi::DataType::INT32,
+                   phi::DataType::INT64})},
     {"deformable_conv_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"deformable_conv", XPUKernelSet({phi::DataType::FLOAT32})},
     {"deformable_conv_v1_grad", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -714,6 +718,15 @@ XPUOpMap& get_kl2_ops() {
     {"tanh", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"temporal_shift", XPUKernelSet({phi::DataType::FLOAT32})},
     {"temporal_shift_grad", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"transfer_dtype",
+     XPUKernelSet({phi::DataType::FLOAT32,
+                   phi::DataType::FLOAT16,
+                   phi::DataType::FLOAT64,
+                   phi::DataType::BOOL,
+                   phi::DataType::UINT8,
+                   phi::DataType::INT8,
+                   phi::DataType::INT64,
+                   phi::DataType::INT32})},
     {"tril_triu",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT32,
@@ -844,6 +857,7 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
     {"randint", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
     {"group_norm", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"group_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"meshgrid",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT32,
@@ -62,7 +62,7 @@ struct XPUContext::Impl {
     std::string cur_thread_name = phi::GetCurrentThreadName();
     VLOG(3) << "XPU Dataloader: current thread at Get Context = "
             << phi::GetCurrentThreadName();
-    bool is_dataloader_thread = (cur_thread_name.substr(0, 10) == "Dataloader");
+    bool is_dataloader_thread = (cur_thread_name != "MainThread");
     return is_dataloader_thread;
   }
@@ -146,6 +146,11 @@ struct XPUContext::Impl {
     backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
     PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
     xpu_wait(context_->xpu_stream);
+    xpu::Context* ctx_t = GetXdlCtx();
+    if (ctx_t) {
+      PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr.");
+      xpu_wait(ctx_t->xpu_stream);
+    }
   }
   void Init() {
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cumprod_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/cumprod.h"
namespace phi {
template <typename T, typename Context>
void CumprodKernel(const Context& dev_ctx,
const DenseTensor& input,
int dim,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
const DenseTensor* x = &input;
auto* x_data = x->data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
DDim shape = x->dims();
std::vector<int64_t> xshape = phi::vectorize<int64_t>(shape);
if (dim < 0) dim += xshape.size();
if (shape.size() == 0) {
int r =
xpu::copy<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(input.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
input.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}
int r = xpu::cumprod(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<XPUType*>(out_data),
xshape,
dim);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumprod");
}
} // namespace phi
PD_REGISTER_KERNEL(
cumprod, XPU, ALL_LAYOUT, phi::CumprodKernel, float, int, int64_t) {}
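
For reference, the kernel above follows NumPy's cumprod semantics: a negative dim is normalized by the input rank, and a rank-0 input degenerates to a copy. A minimal NumPy sketch of the same behavior (illustration only, not part of this commit):

import numpy as np

def cumprod_ref(x, dim):
    # normalize a negative dim by the rank, as the XPU kernel does
    if dim < 0:
        dim += x.ndim
    # rank-0 input: cumprod of a single element is the element itself
    if x.ndim == 0:
        return x.copy()
    return np.cumprod(x, axis=dim)

assert cumprod_ref(np.array(3.0), 0) == 3.0
assert np.allclose(cumprod_ref(np.arange(1, 7).reshape(2, 3), -1),
                   [[1, 2, 6], [4, 20, 120]])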
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/group_norm_grad_kernel.h"
#include <algorithm>
#include <array>
#include <numeric>
#include <string>
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void GroupNormGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias,
const DenseTensor& y,
const DenseTensor& mean,
const DenseTensor& var,
const DenseTensor& d_y,
float epsilon,
int groups,
const std::string& data_layout_str,
DenseTensor* d_x,
DenseTensor* d_scale,
DenseTensor* d_bias) {
using XPUType = typename XPUTypeTrait<T>::Type;
const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
const auto scale_ptr = scale.get_ptr();
const auto bias_ptr = bias.get_ptr();
const auto x_dims = phi::vectorize<int>(x.dims());
const int N = x_dims[0];
const bool channel_first =
data_layout == DataLayout::kNCHW || data_layout == DataLayout::kNCDHW;
const int C = (channel_first ? x_dims[1] : x_dims[x_dims.size() - 1]);
const int L =
(channel_first
? std::accumulate(
x_dims.begin() + 2, x_dims.end(), 1, std::multiplies<int>())
: std::accumulate(x_dims.begin() + 1,
x_dims.end() - 1,
1,
std::multiplies<int>()));
dev_ctx.template Alloc<T>(d_x);
phi::funcs::SetConstant<XPUContext, T> set_zero;
auto* x_data = x.data<T>();
auto* y_data = y.data<T>();
auto* d_x_data = d_x->data<T>();
auto* d_y_data = d_y.data<T>();
auto* mean_data = mean.data<T>();
auto* var_data = var.data<T>();
T* d_scale_data = nullptr;
if (d_scale) {
dev_ctx.template Alloc<T>(d_scale);
set_zero(dev_ctx, d_scale, static_cast<T>(0));
d_scale_data = d_scale->data<T>();
}
T* d_bias_data = nullptr;
if (d_bias) {
dev_ctx.template Alloc<T>(d_bias);
set_zero(dev_ctx, d_bias, static_cast<T>(0));
d_bias_data = d_bias->data<T>();
}
const T* scale_data = nullptr;
if (scale_ptr) scale_data = scale_ptr->data<T>();
const T* bias_data = nullptr;
if (bias_ptr) bias_data = bias_ptr->data<T>();
int r = xpu::group_norm_grad<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<const XPUType*>(y_data),
reinterpret_cast<const XPUType*>(d_y_data),
reinterpret_cast<XPUType*>(d_x_data),
N,
C,
L,
1,
groups,
static_cast<XPUType>(epsilon),
reinterpret_cast<const XPUType*>(scale_data),
reinterpret_cast<const XPUType*>(bias_data),
reinterpret_cast<const XPUType*>(mean_data),
reinterpret_cast<const XPUType*>(var_data),
reinterpret_cast<XPUType*>(d_scale_data),
reinterpret_cast<XPUType*>(d_bias_data),
channel_first);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "group_norm_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(
group_norm_grad, XPU, ALL_LAYOUT, phi::GroupNormGradKernel, float) {}
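
The N/C/L factorization above collapses all spatial dimensions into a single length L, so one call covers 4-D (NCHW/NHWC) and 5-D (NCDHW/NDHWC) inputs alike. A small NumPy sketch of that derivation (illustration only; x_dims and channel_first mirror the variables in the kernel above):

import numpy as np

def group_norm_dims(x_dims, channel_first):
    # N is always the leading (batch) dimension
    N = x_dims[0]
    if channel_first:
        # NCHW / NCDHW: channels sit right after the batch dim
        C, L = x_dims[1], int(np.prod(x_dims[2:]))
    else:
        # NHWC / NDHWC: channels sit last
        C, L = x_dims[-1], int(np.prod(x_dims[1:-1]))
    return N, C, L

assert group_norm_dims([2, 8, 4, 4], True) == (2, 8, 16)   # NCHW
assert group_norm_dims([2, 4, 4, 8], False) == (2, 8, 16)  # NHWC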
@@ -116,8 +116,7 @@ void Pool2dGradKernel(const Context& ctx,
     // and broadcast kernels to get same output, but better performance.
     // Since the dim is special in particular models,
     // use 'export XPU_POOLING_GRAD_SPECIAL=1' to open this path
-    if (out_h == 1 && out_w == 1 && std::is_same<T, float>::value &&
-        std::getenv("XPU_POOLING_GRAD_SPECIAL") != nullptr) {
+    if (out_h == 1 && out_w == 1 && std::is_same<T, float>::value) {
       xpu::ctx_guard RAII_GUARD(ctx.x_context());
       float scale = 1.0 / (in_h * in_w);
       float* scaled_dy = RAII_GUARD.alloc_l3_or_gm<float>(n * c);
@@ -301,8 +300,7 @@ void Pool3dGradKernel(const Context& ctx,
   } else if (pooling_type == "avg") {
     if (out_d == 1 && out_h == 1 && out_w == 1 &&
-        std::is_same<T, float>::value &&
-        std::getenv("XPU_POOLING_GRAD_SPECIAL") != nullptr) {
+        std::is_same<T, float>::value) {
       xpu::ctx_guard RAII_GUARD(ctx.x_context());
       float scale = 1.0 / (in_d * in_h * in_w);
       float* scaled_dy = RAII_GUARD.alloc_l3_or_gm<float>(n * c);
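
The special path taken in both hunks rests on a simple identity: with a 1x1 (or 1x1x1) output, average pooling is a global mean, so every input position receives the same gradient dy / (window size); scaling dy once and broadcasting it is equivalent to the generic pooling-grad kernel. A NumPy sketch of that equivalence (illustration only, not part of this commit):

import numpy as np

n, c, in_h, in_w = 2, 3, 5, 7
dy = np.random.rand(n, c, 1, 1).astype(np.float32)

# gradient of global average pooling: broadcast dy, scaled by 1/(in_h*in_w)
scale = 1.0 / (in_h * in_w)
dx = np.broadcast_to(dy * scale, (n, c, in_h, in_w))

# summing the gradient over each window recovers dy exactly
assert np.allclose(dx.sum(axis=(2, 3), keepdims=True), dy)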
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import sys
import unittest

import numpy as np

sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
    XPUOpTestWrapper,
    create_test_class,
    get_xpu_op_support_types,
)

import paddle

np.random.seed(0)


# define cumprod grad function.
def cumprod_grad(x, y, dy, dx, shape, dim):
    if dim < 0:
        dim += len(shape)
    mid_dim = shape[dim]
    outer_dim = 1
    inner_dim = 1
    for i in range(0, dim):
        outer_dim *= shape[i]
    for i in range(dim + 1, len(shape)):
        inner_dim *= shape[i]
    for i in range(outer_dim):
        for k in range(inner_dim):
            for j in range(mid_dim):
                index = i * mid_dim * inner_dim + j * inner_dim + k
                for n in range(mid_dim):
                    pos = i * mid_dim * inner_dim + n * inner_dim + k
                    elem = 0
                    if j == 0:
                        elem = dy[pos]
                    else:
                        elem = dy[pos] * y[index - inner_dim]
                    if pos > index:
                        for m in range(
                            index + inner_dim, pos + inner_dim, inner_dim
                        ):
                            elem *= x[m]
                    elif pos < index:
                        elem = 0
                    dx[index] += elem


# test function.
class XPUTestCumprodOP(XPUOpTestWrapper):
    def __init__(self):
        self.op_name = 'cumprod'
        self.use_dynamic_create_class = False

    class TestCumprod(XPUOpTest):
        def init_params(self):
            self.shape = (2, 3, 4, 5)
            self.zero_nums = [0, 10, 20, 30, int(np.prod(self.shape))]

        def init_dtype(self):
            self.dtype = self.in_type

        def setUp(self):
            paddle.enable_static()
            self.place = paddle.XPUPlace(0)
            self.init_params()
            self.init_dtype()
            self.op_type = "cumprod"
            self.python_api = paddle.cumprod
            self.inputs = {'X': None}
            self.outputs = {'Out': None}
            self.attrs = {'dim': None}

        def prepare_inputs_outputs_attrs(self, dim, zero_num):
            self.x = np.random.random(self.shape).astype(self.dtype) + 0.5
            if zero_num > 0:
                zero_num = min(zero_num, self.x.size)
                shape = self.x.shape
                self.x = self.x.flatten()
                indices = random.sample(range(self.x.size), zero_num)
                for i in indices:
                    self.x[i] = 0
                self.x = np.reshape(self.x, self.shape)
            self.out = np.cumprod(self.x, axis=dim)
            self.inputs = {'X': self.x}
            self.outputs = {'Out': self.out}
            self.attrs = {'dim': dim}

        def init_grad_input_output(self, dim):
            reshape_x = self.x.reshape(self.x.size)
            self.grad_out = np.ones(self.x.size, self.dtype)
            self.grad_x = np.zeros(self.x.size, self.dtype)
            out_data = self.out.reshape(self.x.size)
            if self.dtype == np.complex128 or self.dtype == np.complex64:
                reshape_x = np.conj(reshape_x)
                out_data = np.conj(out_data)
            cumprod_grad(
                reshape_x, out_data, self.grad_out, self.grad_x, self.shape, dim
            )
            self.grad_x = self.grad_x.reshape(self.shape)
            self.grad_out = self.grad_out.reshape(self.shape)

        # test forward.
        def test_check_output(self):
            for dim in range(-len(self.shape), len(self.shape)):
                for zero_num in self.zero_nums:
                    self.prepare_inputs_outputs_attrs(dim, zero_num)
                    self.check_output_with_place(self.place)

        # test backward.
        def test_check_grad(self):
            pass


# test api.
class TestCumprodAPI(unittest.TestCase):
    def init_dtype(self):
        self.dtype = 'float32'
        self.shape = [2, 3, 10, 10]

    def setUp(self):
        paddle.enable_static()
        self.init_dtype()
        self.x = (np.random.rand(2, 3, 10, 10) + 0.5).astype(self.dtype)
        self.place = [paddle.XPUPlace(0)]

    # test static graph api.
    def test_static_api(self):
        paddle.enable_static()

        def run(place):
            with paddle.static.program_guard(paddle.static.Program()):
                x = paddle.static.data('X', self.shape, dtype=self.dtype)
                out = paddle.cumprod(x, -2)
                exe = paddle.static.Executor(place)
                res = exe.run(feed={'X': self.x}, fetch_list=[out])
            out_ref = np.cumprod(self.x, -2)
            for r in res:
                np.testing.assert_allclose(out_ref, r, rtol=1e-05)

        for place in self.place:
            run(place)

    # test dynamic graph api.
    def test_dygraph_api(self):
        def run(place):
            paddle.disable_static(place)
            x = paddle.to_tensor(self.x)
            out = paddle.cumprod(x, 1)
            out_ref = np.cumprod(self.x, 1)
            np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05)
            paddle.enable_static()

        for place in self.place:
            run(place)


support_types = get_xpu_op_support_types('cumprod')
for stype in support_types:
    create_test_class(globals(), XPUTestCumprodOP, stype)

if __name__ == "__main__":
    unittest.main()
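
The cumprod_grad reference above can be sanity-checked against central finite differences. A minimal sketch (illustration only, not part of the test file; assumes cumprod_grad from the file above is in scope and an input without zeros, so the product is differentiable everywhere):

import numpy as np

shape, dim = (2, 3, 4), 1
x = np.random.rand(*shape) + 0.5  # strictly positive, no zeros
y = np.cumprod(x, axis=dim)

# analytic gradient of f(x) = sum(cumprod(x, dim)) via the reference function
dy = np.ones(x.size)
dx = np.zeros(x.size)
cumprod_grad(x.flatten(), y.flatten(), dy, dx, shape, dim)

# central finite differences on the same f
eps, fd = 1e-6, np.zeros(x.size)
flat = x.flatten()
for i in range(flat.size):
    xp, xm = flat.copy(), flat.copy()
    xp[i] += eps
    xm[i] -= eps
    fd[i] = (np.cumprod(xp.reshape(shape), axis=dim).sum()
             - np.cumprod(xm.reshape(shape), axis=dim).sum()) / (2 * eps)

np.testing.assert_allclose(dx, fd, rtol=1e-4)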
@@ -92,7 +92,9 @@ class XPUTestGroupNormOp(XPUOpTestWrapper):
             self.check_output_with_place(paddle.XPUPlace(0))

         def test_check_grad(self):
-            pass
+            self.check_grad_with_place(
+                paddle.XPUPlace(0), ['X', 'Scale', 'Bias'], 'Y'
+            )

     class TestGroupNormOp2(TestGroupNormOp):
         def init_test_case(self):