未验证 提交 48f5af99 编写于 作者: Z zhangyuqin1998 提交者: GitHub

move raw kernels to legacy (#53913)

* move raw kernels to legacy

* Update elementwise_add_kernel.cu

* fix
上级 8032d57e
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h"
......@@ -22,15 +23,43 @@
namespace phi {
// Create the definition of Add
DEFINE_CPU_ELEMENTWISE_OP(Add)
template <typename T, typename Context>
void AddFunctor(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
if (x.dims() == y.dims()) {
SameDimsElementwiseCompute<SameDimsAddFunctor<CPUContext, T>>()(
dev_ctx, x, y, out);
} else {
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::AddFunctor<T>, T>(
dev_ctx, x, y, funcs::AddFunctor<T>(), out, axis);
} else {
funcs::ElementwiseCompute<funcs::InverseAddFunctor<T>, T>(
dev_ctx, x, y, funcs::InverseAddFunctor<T>(), out, axis);
}
}
}
template <typename T, typename Context>
void AddKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
AddFunctor<T, Context>(dev_ctx, x, y, -1, out);
}
template <typename T, typename Context>
void GradAddKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
AddRawKernel<T>(dev_ctx, x, y, -1, out);
AddFunctor<T>(dev_ctx, x, y, -1, out);
}
} // namespace phi
......@@ -41,10 +70,10 @@ using complex128 = ::phi::dtype::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::phi::dtype::bfloat16;
PD_REGISTER_KERNEL(add_raw,
PD_REGISTER_KERNEL(add,
CPU,
ALL_LAYOUT,
phi::AddRawKernel,
phi::AddKernel,
float,
double,
int16_t,
......
......@@ -22,9 +22,27 @@
namespace phi {
// Create the definition of Subtract
DEFINE_CPU_ELEMENTWISE_OP(Subtract)
template <typename T, typename Context>
void SubtractKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
if (x.dims() == y.dims()) {
SameDimsElementwiseCompute<SameDimsSubtractFunctor<CPUContext, T>>()(
dev_ctx, x, y, out);
} else {
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T>(
dev_ctx, x, y, funcs::SubtractFunctor<T>(), out, -1);
} else {
funcs::ElementwiseCompute<funcs::InverseSubtractFunctor<T>, T>(
dev_ctx, x, y, funcs::InverseSubtractFunctor<T>(), out, -1);
}
}
}
} // namespace phi
using complex64 = ::phi::dtype::complex<float>;
......@@ -33,10 +51,10 @@ using complex128 = ::phi::dtype::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::phi::dtype::bfloat16;
PD_REGISTER_KERNEL(subtract_raw,
PD_REGISTER_KERNEL(subtract,
CPU,
ALL_LAYOUT,
phi::SubtractRawKernel,
phi::SubtractKernel,
float,
double,
int16_t,
......
......@@ -18,13 +18,6 @@
#include "paddle/phi/infermeta/binary.h"
namespace phi {
template <typename T, typename Context>
void AddRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void AddKernel(const Context& dev_ctx,
const DenseTensor& x,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/elementwise_divide_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void AddKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
AddRawKernel<T, Context>(dev_ctx, x, y, -1, out);
}
template <typename T, typename Context>
void SubtractKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
int axis = -1;
SubtractRawKernel<T>(dev_ctx, x, y, axis, out);
}
} // namespace phi
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(subtract,
CPU,
ALL_LAYOUT,
phi::SubtractKernel,
float,
double,
int16_t,
int,
int64_t,
complex64,
complex128,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(add,
CPU,
ALL_LAYOUT,
phi::AddKernel,
float,
double,
int16_t,
int,
int64_t,
complex64,
complex128) {}
#if defined(PADDLE_WITH_XPU_KP) && defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(subtract, KPS, ALL_LAYOUT, phi::SubtractKernel, float) {}
PD_REGISTER_KERNEL(add, KPS, ALL_LAYOUT, phi::AddKernel, float) {}
#elif defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(subtract,
KPS,
ALL_LAYOUT,
phi::SubtractKernel,
float,
double,
int16_t,
int,
int64_t,
phi::dtype::float16,
complex64,
complex128,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(add,
KPS,
ALL_LAYOUT,
phi::AddKernel,
float,
double,
int16_t,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
complex64,
complex128) {}
#endif
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(add,
XPU,
ALL_LAYOUT,
phi::AddKernel,
phi::dtype::float16,
float,
int,
int64_t) {}
PD_REGISTER_KERNEL(subtract,
XPU,
ALL_LAYOUT,
phi::SubtractKernel,
float,
phi::dtype::float16,
int64_t) {}
#endif
......@@ -19,13 +19,6 @@
namespace phi {
template <typename T, typename Context>
void SubtractRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void SubtractKernel(const Context& dev_ctx,
const DenseTensor& x,
......
......@@ -84,7 +84,7 @@ void CholeskySolveGradKernel(const Context& dev_ctx,
DenseTensor commonterm_conj = Conj<T, Context>(dev_ctx, commonterm);
commonterm_conj = phi::TransposeLast2Dim<T>(dev_ctx, commonterm_conj);
phi::AddRawKernel<T>(dev_ctx, commonterm, commonterm_conj, -1, &commonterm);
phi::AddKernel<T>(dev_ctx, commonterm, commonterm_conj, &commonterm);
DenseTensor dy_bst = phi::Empty<T, Context>(dev_ctx, y_bst_dims);
if (upper) {
......
......@@ -237,7 +237,7 @@ void Tensor_Add(const Context& dev_ctx,
out->Resize(src1.dims());
dev_ctx.template Alloc<T>(out);
phi::AddRawKernel<T, Context>(dev_ctx, src1, src2, -1, out);
phi::AddKernel<T, Context>(dev_ctx, src1, src2, out);
}
template <typename Context, typename T>
......@@ -248,7 +248,7 @@ void Tensor_Sub(const Context& dev_ctx,
out->Resize(src1.dims());
dev_ctx.template Alloc<T>(out);
phi::SubtractRawKernel<T, Context>(dev_ctx, src1, src2, -1, out);
phi::SubtractKernel<T, Context>(dev_ctx, src1, src2, out);
}
template <typename Context, typename T, size_t D>
......
......@@ -18,24 +18,49 @@
#include "paddle/phi/common/float16.h"
#endif
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
namespace phi {
DEFINE_CUDA_ELEMENTWISE_OP(Add)
template <typename T, typename Context>
void AddCudaFunctor(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
inputs.reserve(2);
std::vector<DenseTensor*> outputs;
outputs.reserve(1);
inputs.emplace_back(&x);
inputs.emplace_back(&y);
outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, funcs::AddFunctor<T>(), axis);
}
template <typename T, typename Context>
void AddKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
AddCudaFunctor<T, Context>(dev_ctx, x, y, -1, out);
}
template <typename T, typename Context>
void GradAddKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
AddRawKernel<T>(dev_ctx, x, y, -1, out);
AddCudaFunctor<T>(dev_ctx, x, y, -1, out);
}
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(add_raw, KPS, ALL_LAYOUT, phi::AddRawKernel, float) {}
PD_REGISTER_KERNEL(add, KPS, ALL_LAYOUT, phi::AddKernel, float) {}
#else
using float16 = phi::dtype::float16;
......@@ -43,17 +68,17 @@ using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(add_raw,
PD_REGISTER_KERNEL(add,
KPS,
ALL_LAYOUT,
phi::AddRawKernel,
phi::AddKernel,
float,
double,
int16_t,
int,
int64_t,
float16,
bfloat16,
phi::dtype::float16,
phi::dtype::bfloat16,
complex64,
complex128) {}
......
......@@ -22,14 +22,27 @@
namespace phi {
// Create the definition of Subtract
DEFINE_CUDA_ELEMENTWISE_OP(Subtract)
template <typename T, typename Context>
void SubtractKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
inputs.reserve(2);
std::vector<DenseTensor*> outputs;
outputs.reserve(1);
inputs.emplace_back(&x);
inputs.emplace_back(&y);
outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, funcs::SubtractFunctor<T>(), -1);
}
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(
subtract_raw, KPS, ALL_LAYOUT, phi::SubtractRawKernel, float) {}
PD_REGISTER_KERNEL(subtract, KPS, ALL_LAYOUT, phi::SubtractKernel, float) {}
#else
using float16 = phi::dtype::float16;
......@@ -37,10 +50,10 @@ using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(subtract_raw,
PD_REGISTER_KERNEL(subtract,
KPS,
ALL_LAYOUT,
phi::SubtractRawKernel,
phi::SubtractKernel,
float,
double,
int16_t,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
namespace phi {
// Create the definition of Add
DEFINE_CPU_ELEMENTWISE_OP(Add)
} // namespace phi
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::phi::dtype::bfloat16;
PD_REGISTER_KERNEL(add_raw,
CPU,
ALL_LAYOUT,
phi::AddRawKernel,
float,
double,
int16_t,
int,
int64_t,
complex64,
complex128) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
namespace phi {
// Create the definition of Subtract
DEFINE_CPU_ELEMENTWISE_OP(Subtract)
} // namespace phi
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::phi::dtype::bfloat16;
PD_REGISTER_KERNEL(subtract_raw,
CPU,
ALL_LAYOUT,
phi::SubtractRawKernel,
float,
double,
int16_t,
int,
int64_t,
complex64,
complex128,
phi::dtype::bfloat16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"
namespace phi {
template <typename T, typename Context>
void AddRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"
namespace phi {
template <typename T, typename Context>
void SubtractRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#endif
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
namespace phi {
DEFINE_CUDA_ELEMENTWISE_OP(Add)
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(add_raw, KPS, ALL_LAYOUT, phi::AddRawKernel, float) {}
#else
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(add_raw,
KPS,
ALL_LAYOUT,
phi::AddRawKernel,
float,
double,
int16_t,
int,
int64_t,
float16,
bfloat16,
complex64,
complex128) {}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#endif
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
namespace phi {
// Create the definition of Subtract
DEFINE_CUDA_ELEMENTWISE_OP(Subtract)
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(
subtract_raw, KPS, ALL_LAYOUT, phi::SubtractRawKernel, float) {}
#else
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(subtract_raw,
KPS,
ALL_LAYOUT,
phi::SubtractRawKernel,
float,
double,
int16_t,
int,
int64_t,
float16,
bfloat16,
complex64,
complex128) {}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include <memory>
#include <string>
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/backends/xpu/xpu_header.h"
#include "paddle/phi/backends/xpu/xpu_info.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
#include "paddle/phi/kernels/xpu/elementwise.h"
namespace phi {
template <typename T, typename Context>
void AddRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx,
const XPUType* x,
const XPUType* y,
XPUType* z,
const std::vector<int>& xshape,
const std::vector<int>& yshape) {
return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
};
XPUElementwise<T, XPUType>(dev_ctx, x, y, axis, out, f);
}
} // namespace phi
PD_REGISTER_KERNEL(add_raw,
XPU,
ALL_LAYOUT,
phi::AddRawKernel,
phi::dtype::float16,
float,
int,
int64_t) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/backends/xpu/xpu_header.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/xpu/elementwise.h"
namespace phi {
template <typename T, typename Context>
void SubtractRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx,
const XPUType* x,
const XPUType* y,
XPUType* z,
const std::vector<int>& xshape,
const std::vector<int>& yshape) {
return xpu::broadcast_sub<XPUType>(ctx, x, y, z, xshape, yshape);
};
phi::XPUElementwise<T, XPUType>(dev_ctx, x, y, axis, out, f);
}
} // namespace phi
PD_REGISTER_KERNEL(subtract_raw,
XPU,
ALL_LAYOUT,
phi::SubtractRawKernel,
float,
phi::dtype::float16,
int64_t) {}
......@@ -28,6 +28,25 @@
namespace phi {
template <typename T, typename Context>
void AddKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx,
const XPUType* x,
const XPUType* y,
XPUType* z,
const std::vector<int>& xshape,
const std::vector<int>& yshape) {
return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
};
XPUElementwise<T, XPUType>(dev_ctx, x, y, -1, out, f);
}
template <typename T, typename Context>
void GradAddXPUKernel(const Context& dev_ctx,
const DenseTensor& x,
......@@ -47,26 +66,6 @@ void GradAddXPUKernel(const Context& dev_ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
}
template <typename T, typename Context>
void AddRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx,
const XPUType* x,
const XPUType* y,
XPUType* z,
const std::vector<int>& xshape,
const std::vector<int>& yshape) {
return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
};
XPUElementwise<T, XPUType>(dev_ctx, x, y, axis, out, f);
}
} // namespace phi
PD_REGISTER_KERNEL(grad_add,
......@@ -75,10 +74,11 @@ PD_REGISTER_KERNEL(grad_add,
phi::GradAddXPUKernel,
phi::dtype::float16,
float) {}
PD_REGISTER_KERNEL(add_raw,
PD_REGISTER_KERNEL(add,
XPU,
ALL_LAYOUT,
phi::AddRawKernel,
phi::AddKernel,
phi::dtype::float16,
float,
int,
......
......@@ -20,11 +20,10 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void SubtractRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
void SubtractKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx,
const XPUType* x,
......@@ -35,14 +34,14 @@ void SubtractRawKernel(const Context& dev_ctx,
return xpu::broadcast_sub<XPUType>(ctx, x, y, z, xshape, yshape);
};
phi::XPUElementwise<T, XPUType>(dev_ctx, x, y, axis, out, f);
phi::XPUElementwise<T, XPUType>(dev_ctx, x, y, -1, out, f);
}
} // namespace phi
PD_REGISTER_KERNEL(subtract_raw,
PD_REGISTER_KERNEL(subtract,
XPU,
ALL_LAYOUT,
phi::SubtractRawKernel,
phi::SubtractKernel,
float,
phi::dtype::float16,
int64_t) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册