未验证 提交 ee003457 编写于 作者: W wz1qqx 提交者: GitHub

[XPU]Add flip kernel (#55932)

上级 eadc5d07
...@@ -91,7 +91,8 @@ AddLayernormXPUPattern::AddLayernormXPUPattern(PDPattern* pattern, ...@@ -91,7 +91,8 @@ AddLayernormXPUPattern::AddLayernormXPUPattern(PDPattern* pattern,
->AsInput(); ->AsInput();
auto ele_out = pattern->NewNode(ele_out_repr()) auto ele_out = pattern->NewNode(ele_out_repr())
->assert_is_op_output("elementwise_add", "Out") ->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("layer_norm", "X"); ->assert_is_op_input("layer_norm", "X")
->assert_has_n_outputs(1);
ele_add->LinksFrom({add_x, add_y}).LinksTo({ele_out}); ele_add->LinksFrom({add_x, add_y}).LinksTo({ele_out});
auto l_norm = pattern->NewNode(l_norm_repr())->assert_is_op("layer_norm"); auto l_norm = pattern->NewNode(l_norm_repr())->assert_is_op("layer_norm");
auto norm_bias = pattern->NewNode(norm_bias_repr()) auto norm_bias = pattern->NewNode(norm_bias_repr())
......
...@@ -24,7 +24,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -24,7 +24,8 @@ XPUOpMap& get_kl2_ops() {
static XPUOpMap s_xpu2_kernels{ static XPUOpMap s_xpu2_kernels{
{"add_act_xpu", {"add_act_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"add_layernorm_xpu", XPUKernelSet({phi::DataType::FLOAT32})}, {"add_layernorm_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"abs", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"abs", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"abs_grad", {"abs_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
...@@ -371,6 +372,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -371,6 +372,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32, phi::DataType::INT32,
phi::DataType::INT8, phi::DataType::INT8,
phi::DataType::FLOAT32})}, phi::DataType::FLOAT32})},
{"flip", XPUKernelSet({phi::DataType::FLOAT32})},
{"full_batch_size_like", {"full_batch_size_like",
XPUKernelSet({phi::DataType::INT64, XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32, phi::DataType::INT32,
......
...@@ -119,4 +119,5 @@ PD_REGISTER_KERNEL(add_layernorm_xpu, ...@@ -119,4 +119,5 @@ PD_REGISTER_KERNEL(add_layernorm_xpu,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::fusion::AddLayernormXPUKernel, phi::fusion::AddLayernormXPUKernel,
float) {} float,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/flip_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void FlipKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
using XPUInTDType = typename XPUTypeTrait<T>::Type;
int x_rank = x.dims().size();
std::vector<int64_t> formated_axis(std::begin(axis), std::end(axis));
for (size_t i = 0; i < axis.size(); i++) {
if (axis[i] < 0) {
formated_axis[i] = static_cast<int64_t>(axis[i] + x_rank);
}
}
dev_ctx.template Alloc<T>(out);
if (out->numel() == 0) {
return;
}
if (formated_axis.size() == 0) {
phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
return;
}
std::vector<int64_t> x_shape = phi::vectorize(x.dims());
auto x_data = reinterpret_cast<const XPUInTDType*>(x.data<T>());
auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
auto numel = x.numel();
if (numel <= 0) {
return;
}
int r = xpu::flip<XPUInTDType>(
/* Context* ctx */ dev_ctx.x_context(),
/* const T* x */ x_data,
/* T* y */ out_data,
/* const std::vector<int64_t>& xshape */ x_shape,
/* const std::vector<int64_t>& axis */ formated_axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "flip");
}
} // namespace phi
PD_REGISTER_KERNEL(flip, XPU, ALL_LAYOUT, phi::FlipKernel, float) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册