Unverified Commit 7271de88 Authored by: R ronnywang Committed by: GitHub

[XPU] add expand_grad, isnan, meshgrid kernels (#50774)

* [XPU] add expand_grad, isnan, meshgrid kernels

* update
Parent 612d5da0
@@ -773,6 +773,14 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
{"randint", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"group_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"meshgrid",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
{"expand_v2_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
{"isnan_v2",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
// AddMore
{"sequence_conv", XPUKernelSet({phi::DataType::FLOAT32})},
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/expand_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ExpandGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& shape,
DenseTensor* in_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto in_grad_data = ctx.template Alloc<T>(in_grad);
auto out_grad_dims = phi::vectorize<int64_t>(out_grad.dims());
auto in_grad_dims = phi::vectorize<int64_t>(in_grad->dims());
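  // Left-pad in_grad's shape with leading 1s so it matches out_grad's rank;
  // xpu::expand_grad then reduces over the broadcast axes.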
in_grad_dims.insert(
in_grad_dims.begin(), out_grad.dims().size() - in_grad->dims().size(), 1);
int r = xpu::expand_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad.data<T>()),
reinterpret_cast<XPUType*>(in_grad_data),
out_grad_dims,
in_grad_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "expand_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(
    expand_grad, XPU, ALL_LAYOUT, phi::ExpandGradKernel, float, int) {}
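For a quick sanity check from the Python side, a minimal dygraph sketch (assuming a Paddle build with XPU support and one visible XPU device) that exercises the expand backward path:

```python
import paddle

paddle.set_device('xpu')  # assumes an XPU build with a visible device

x = paddle.rand([1, 3])
x.stop_gradient = False
# Broadcast [1, 3] -> [4, 3]; the backward pass reduces the expanded
# axis back to x's shape, which is what expand_grad computes.
y = paddle.expand(x, shape=[4, 3])
y.sum().backward()
print(x.grad.shape)  # [1, 3]; each element accumulates 4 gradient copies
```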
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/isfinite_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void IsnanKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto* out_data = ctx.template Alloc<bool>(out);
int r = xpu::isnan<XPUType>(ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
out_data,
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "isnan");
}
} // namespace phi
PD_REGISTER_KERNEL(
    isnan, XPU, ALL_LAYOUT, phi::IsnanKernel, float, phi::dtype::float16) {
  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
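A similar hedged sketch for the isnan kernel (same XPU assumptions); note that the registration pins the output dtype to BOOL regardless of the float32/float16 input:

```python
import paddle

paddle.set_device('xpu')  # assumes an XPU build with a visible device

x = paddle.to_tensor([1.0, float('nan'), 3.0])
out = paddle.isnan(x)  # bool tensor
print(out.numpy())  # [False  True False]
```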
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/meshgrid_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void MeshgridKernel(const Context& ctx,
const std::vector<const DenseTensor*>& inputs,
std::vector<DenseTensor*> outputs) {
using XPUType = typename XPUTypeTrait<T>::Type;
std::vector<const XPUType*> x_list;
std::vector<XPUType*> y_list;
std::vector<std::vector<int64_t>> xshape_list;
for (const auto& x : inputs) {
x_list.push_back(reinterpret_cast<const XPUType*>(x->data<T>()));
xshape_list.emplace_back(phi::vectorize<int64_t>(x->dims()));
}
for (auto& x : outputs) {
ctx.template Alloc<T>(x);
y_list.push_back(reinterpret_cast<XPUType*>(x->data<T>()));
}
int r = xpu::meshgrid<XPUType>(ctx.x_context(), x_list, y_list, xshape_list);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "meshgrid");
}
} // namespace phi
PD_REGISTER_KERNEL(
meshgrid, XPU, ALL_LAYOUT, phi::MeshgridKernel, float, int, int64_t) {}
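And one for meshgrid (same assumptions); Paddle's meshgrid follows 'ij' indexing, so each output broadcasts one input along the other axes:

```python
import paddle

paddle.set_device('xpu')  # assumes an XPU build with a visible device

x = paddle.arange(3, dtype='float32')
y = paddle.arange(4, dtype='float32')
gx, gy = paddle.meshgrid(x, y)  # both outputs have shape [3, 4]
print(gx.shape, gy.shape)
```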
@@ -68,6 +68,9 @@ class XPUTestExpandV2Op(XPUOpTestWrapper):
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ["X"], "Out")
class TestExpandV2OpRank2_DimExpanding(TestExpandV2XPUOp):
def init_data(self):
self.ori_shape = [120]
@@ -189,7 +192,7 @@ class TestExpandV2OpInteger(XPUOpTest):
self.check_output_with_place(self.place)
def test_check_grad(self):
-        pass
+        self.check_grad_with_place(self.place, ["X"], "Out")
# Test python API
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
import paddle
paddle.enable_static()
np.random.seed(10)
class XPUTestIsNANOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'isnan_v2'
self.use_dynamic_create_class = False
class TestIsNAN(XPUOpTest):
def setUp(self):
self.init_dtype()
self.set_xpu()
self.op_type = "isnan_v2"
self.place = paddle.XPUPlace(0)
self.set_inputs()
self.set_output()
def init_dtype(self):
self.dtype = self.in_type
def set_inputs(self):
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
x[0] = np.nan
x[-1] = np.nan
self.inputs = {'X': x}
def set_output(self):
self.outputs = {'Out': np.isnan(self.inputs['X']).astype(bool)}
def set_xpu(self):
self.__class__.use_xpu = True
self.__class__.no_need_check_grad = True
self.__class__.op_type = self.in_type
def test_check_output(self):
self.check_output_with_place(self.place)
support_types = get_xpu_op_support_types('isnan_v2')
for stype in support_types:
create_test_class(globals(), XPUTestIsNANOp, stype)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
import paddle
paddle.enable_static()
np.random.seed(10)
class XPUTestMeshGridOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'meshgrid'
self.use_dynamic_create_class = False
class TestMeshGrid(XPUOpTest):
def setUp(self):
self.init_dtype()
self.set_xpu()
self.op_type = "meshgrid"
self.place = paddle.XPUPlace(0)
self.set_inputs()
self.set_output()
def init_dtype(self):
self.dtype = self.in_type
def init_test_data(self):
self.shape = self.get_x_shape()
ins = []
outs = []
for i in range(len(self.shape)):
ins.append(
np.random.random((self.shape[i],)).astype(self.dtype)
)
for i in range(len(self.shape)):
out_reshape = [1] * len(self.shape)
out_reshape[i] = self.shape[i]
out_temp = np.reshape(ins[i], out_reshape)
outs.append(np.broadcast_to(out_temp, self.shape))
return ins, outs
def get_x_shape(self):
return [100, 200]
def set_inputs(self):
ins, outs = self.init_test_data()
self.inputs = {'X': [('x%d' % i, ins[i]) for i in range(len(ins))]}
self.outputs = {
'Out': [('out%d' % i, outs[i]) for i in range(len(outs))]
}
def set_output(self):
pass
def set_xpu(self):
self.__class__.use_xpu = True
self.__class__.no_need_check_grad = True
self.__class__.op_type = self.in_type
def test_check_output(self):
self.check_output_with_place(self.place)
class TestMeshgridOp2(TestMeshGrid):
def get_x_shape(self):
return [100, 300]
support_types = get_xpu_op_support_types('meshgrid')
for stype in support_types:
create_test_class(globals(), XPUTestMeshGridOp, stype)
if __name__ == '__main__':
unittest.main()