未验证 提交 051add42 编写于 作者: Z zhangyuqin1998 提交者: GitHub

Move raw kernel to legacy (#53830)

上级 4a4ffe9a
...@@ -23,10 +23,9 @@ ...@@ -23,10 +23,9 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void DivideRawKernel(const Context& dev_ctx, void DivideKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
int axis,
DenseTensor* out) { DenseTensor* out) {
// allocate memory for out // allocate memory for out
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
...@@ -38,10 +37,10 @@ void DivideRawKernel(const Context& dev_ctx, ...@@ -38,10 +37,10 @@ void DivideRawKernel(const Context& dev_ctx,
auto y_dims = y.dims(); auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) { if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>( funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
dev_ctx, x, y, funcs::DivideFunctor<T>(), out, axis); dev_ctx, x, y, funcs::DivideFunctor<T>(), out, -1);
} else { } else {
funcs::ElementwiseCompute<funcs::InverseDivideFunctor<T>, T>( funcs::ElementwiseCompute<funcs::InverseDivideFunctor<T>, T>(
dev_ctx, x, y, funcs::InverseDivideFunctor<T>(), out, axis); dev_ctx, x, y, funcs::InverseDivideFunctor<T>(), out, -1);
} }
} }
} }
...@@ -54,10 +53,10 @@ using complex128 = ::phi::dtype::complex<double>; ...@@ -54,10 +53,10 @@ using complex128 = ::phi::dtype::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::phi::dtype::bfloat16; // using bfloat16 = ::phi::dtype::bfloat16;
PD_REGISTER_KERNEL(divide_raw, PD_REGISTER_KERNEL(divide,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::DivideRawKernel, phi::DivideKernel,
float, float,
double, double,
int, int,
......
...@@ -22,8 +22,27 @@ ...@@ -22,8 +22,27 @@
namespace phi { namespace phi {
// Create the definition of Multiply template <typename T, typename Context>
DEFINE_CPU_ELEMENTWISE_OP(Multiply) void MultiplyKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
if (x.dims() == y.dims()) {
SameDimsElementwiseCompute<SameDimsMultiplyFunctor<CPUContext, T>>()(
dev_ctx, x, y, out);
} else {
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T>(
dev_ctx, x, y, funcs::MultiplyFunctor<T>(), out, -1);
} else {
funcs::ElementwiseCompute<funcs::InverseMultiplyFunctor<T>, T>(
dev_ctx, x, y, funcs::InverseMultiplyFunctor<T>(), out, -1);
}
}
}
} // namespace phi } // namespace phi
...@@ -33,10 +52,10 @@ using complex128 = ::phi::dtype::complex<double>; ...@@ -33,10 +52,10 @@ using complex128 = ::phi::dtype::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::phi::dtype::bfloat16; // using bfloat16 = ::phi::dtype::bfloat16;
PD_REGISTER_KERNEL(multiply_raw, PD_REGISTER_KERNEL(multiply,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::MultiplyRawKernel, phi::MultiplyKernel,
float, float,
double, double,
int, int,
......
...@@ -19,13 +19,6 @@ ...@@ -19,13 +19,6 @@
namespace phi { namespace phi {
template <typename T, typename Context>
void DivideRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void DivideKernel(const Context& dev_ctx, void DivideKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
......
...@@ -23,22 +23,6 @@ ...@@ -23,22 +23,6 @@
namespace phi { namespace phi {
template <typename T, typename Context>
void DivideKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
DivideRawKernel<T, Context>(dev_ctx, x, y, -1, out);
}
template <typename T, typename Context>
void MultiplyKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
MultiplyRawKernel<T, Context>(dev_ctx, x, y, -1, out);
}
template <typename T, typename Context> template <typename T, typename Context>
void AddKernel(const Context& dev_ctx, void AddKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -85,35 +69,9 @@ PD_REGISTER_KERNEL(add, ...@@ -85,35 +69,9 @@ PD_REGISTER_KERNEL(add,
complex64, complex64,
complex128) {} complex128) {}
PD_REGISTER_KERNEL(multiply,
CPU,
ALL_LAYOUT,
phi::MultiplyKernel,
float,
double,
int,
int64_t,
bool,
complex64,
complex128,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(divide,
CPU,
ALL_LAYOUT,
phi::DivideKernel,
float,
double,
int,
int64_t,
complex64,
complex128) {}
#if defined(PADDLE_WITH_XPU_KP) && defined(PADDLE_WITH_XPU) #if defined(PADDLE_WITH_XPU_KP) && defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(subtract, KPS, ALL_LAYOUT, phi::SubtractKernel, float) {} PD_REGISTER_KERNEL(subtract, KPS, ALL_LAYOUT, phi::SubtractKernel, float) {}
PD_REGISTER_KERNEL(add, KPS, ALL_LAYOUT, phi::AddKernel, float) {} PD_REGISTER_KERNEL(add, KPS, ALL_LAYOUT, phi::AddKernel, float) {}
PD_REGISTER_KERNEL(multiply, KPS, ALL_LAYOUT, phi::MultiplyKernel, float) {}
PD_REGISTER_KERNEL(divide, KPS, ALL_LAYOUT, phi::DivideKernel, float) {}
#elif defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #elif defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(subtract, PD_REGISTER_KERNEL(subtract,
KPS, KPS,
...@@ -142,39 +100,10 @@ PD_REGISTER_KERNEL(add, ...@@ -142,39 +100,10 @@ PD_REGISTER_KERNEL(add,
phi::dtype::bfloat16, phi::dtype::bfloat16,
complex64, complex64,
complex128) {} complex128) {}
PD_REGISTER_KERNEL(multiply,
KPS,
ALL_LAYOUT,
phi::MultiplyKernel,
float,
double,
int,
int64_t,
bool,
phi::dtype::float16,
phi::dtype::bfloat16,
complex64,
complex128) {}
PD_REGISTER_KERNEL(divide,
KPS,
ALL_LAYOUT,
phi::DivideKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
complex64,
complex128) {}
#endif #endif
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(
divide, XPU, ALL_LAYOUT, phi::DivideKernel, phi::dtype::float16, float) {}
PD_REGISTER_KERNEL(add, PD_REGISTER_KERNEL(add,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -184,14 +113,6 @@ PD_REGISTER_KERNEL(add, ...@@ -184,14 +113,6 @@ PD_REGISTER_KERNEL(add,
int, int,
int64_t) {} int64_t) {}
PD_REGISTER_KERNEL(multiply,
XPU,
ALL_LAYOUT,
phi::MultiplyKernel,
phi::dtype::float16,
float,
int,
int64_t) {}
PD_REGISTER_KERNEL(subtract, PD_REGISTER_KERNEL(subtract,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
......
...@@ -19,13 +19,6 @@ ...@@ -19,13 +19,6 @@
namespace phi { namespace phi {
template <typename T, typename Context>
void MultiplyRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void MultiplyKernel(const Context& dev_ctx, void MultiplyKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
......
...@@ -22,13 +22,27 @@ ...@@ -22,13 +22,27 @@
namespace phi { namespace phi {
// Create the definition of Divide template <typename T, typename Context>
DEFINE_CUDA_ELEMENTWISE_OP(Divide) void DivideKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
inputs.reserve(2);
std::vector<DenseTensor*> outputs;
outputs.reserve(1);
inputs.emplace_back(&x);
inputs.emplace_back(&y);
outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, funcs::DivideFunctor<T>(), -1);
}
} // namespace phi } // namespace phi
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(divide_raw, KPS, ALL_LAYOUT, phi::DivideRawKernel, float) {} PD_REGISTER_KERNEL(divide, KPS, ALL_LAYOUT, phi::DivideKernel, float) {}
#else #else
using float16 = phi::dtype::float16; using float16 = phi::dtype::float16;
...@@ -36,10 +50,10 @@ using bfloat16 = phi::dtype::bfloat16; ...@@ -36,10 +50,10 @@ using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>; using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>; using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(divide_raw, PD_REGISTER_KERNEL(divide,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::DivideRawKernel, phi::DivideKernel,
float, float,
double, double,
int, int,
......
...@@ -22,14 +22,27 @@ ...@@ -22,14 +22,27 @@
namespace phi { namespace phi {
// Create the definition of Multiply template <typename T, typename Context>
DEFINE_CUDA_ELEMENTWISE_OP(Multiply) void MultiplyKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
inputs.reserve(2);
std::vector<DenseTensor*> outputs;
outputs.reserve(1);
inputs.emplace_back(&x);
inputs.emplace_back(&y);
outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out);
funcs::BroadcastKernel<T>(
dev_ctx, inputs, &outputs, funcs::MultiplyFunctor<T>(), -1);
}
} // namespace phi } // namespace phi
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(multiply, KPS, ALL_LAYOUT, phi::MultiplyKernel, float) {}
multiply_raw, KPS, ALL_LAYOUT, phi::MultiplyRawKernel, float) {}
#else #else
using float16 = phi::dtype::float16; using float16 = phi::dtype::float16;
...@@ -37,10 +50,10 @@ using bfloat16 = phi::dtype::bfloat16; ...@@ -37,10 +50,10 @@ using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>; using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>; using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(multiply_raw, PD_REGISTER_KERNEL(multiply,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::MultiplyRawKernel, phi::MultiplyKernel,
float, float,
double, double,
int, int,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void DivideRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
// allocate memory for out
dev_ctx.template Alloc<T>(out);
if (x.dims() == y.dims() && std::is_floating_point<T>::value) {
SameDimsElementwiseCompute<SameDimsDivideFunctor<CPUContext, T>>()(
dev_ctx, x, y, out);
} else {
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
dev_ctx, x, y, funcs::DivideFunctor<T>(), out, axis);
} else {
funcs::ElementwiseCompute<funcs::InverseDivideFunctor<T>, T>(
dev_ctx, x, y, funcs::InverseDivideFunctor<T>(), out, axis);
}
}
}
} // namespace phi
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::phi::dtype::bfloat16;
PD_REGISTER_KERNEL(divide_raw,
CPU,
ALL_LAYOUT,
phi::DivideRawKernel,
float,
double,
int,
int64_t,
complex64,
complex128) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
namespace phi {
// Create the definition of Multiply
DEFINE_CPU_ELEMENTWISE_OP(Multiply)
} // namespace phi
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::phi::dtype::bfloat16;
PD_REGISTER_KERNEL(multiply_raw,
CPU,
ALL_LAYOUT,
phi::MultiplyRawKernel,
float,
double,
int,
int64_t,
bool,
complex64,
complex128,
phi::dtype::bfloat16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MultiplyRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#endif
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
namespace phi {
// Create the definition of Divide
DEFINE_CUDA_ELEMENTWISE_OP(Divide)
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(divide_raw, KPS, ALL_LAYOUT, phi::DivideRawKernel, float) {}
#else
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(divide_raw,
KPS,
ALL_LAYOUT,
phi::DivideRawKernel,
float,
double,
int,
int64_t,
float16,
bfloat16,
complex64,
complex128) {}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#ifndef PADDLE_WITH_XPU_KP
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#endif
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
namespace phi {
// Create the definition of Multiply
DEFINE_CUDA_ELEMENTWISE_OP(Multiply)
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(
multiply_raw, KPS, ALL_LAYOUT, phi::MultiplyRawKernel, float) {}
#else
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(multiply_raw,
KPS,
ALL_LAYOUT,
phi::MultiplyRawKernel,
float,
double,
int,
int64_t,
bool,
float16,
complex64,
complex128,
bfloat16) {}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/elementwise_divide_kernel.h"
#include <memory>
#include <string>
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/xpu/elementwise.h"
namespace phi {
template <typename T, typename Context>
void DivideRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx,
const XPUType* x,
const XPUType* y,
XPUType* z,
const std::vector<int>& xshape,
const std::vector<int>& yshape) {
return xpu::broadcast_div<XPUType>(ctx, x, y, z, xshape, yshape);
};
XPUElementwise<T, XPUType>(dev_ctx, x, y, axis, out, f);
}
} // namespace phi
PD_REGISTER_KERNEL(divide_raw,
XPU,
ALL_LAYOUT,
phi::DivideRawKernel,
phi::dtype::float16,
float) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include <memory>
#include <string>
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/xpu/elementwise.h"
namespace phi {
template <typename T, typename Context>
void MultiplyRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx,
const XPUType* x,
const XPUType* y,
XPUType* z,
const std::vector<int>& xshape,
const std::vector<int>& yshape) {
return xpu::broadcast_mul<XPUType>(ctx, x, y, z, xshape, yshape);
};
XPUElementwise<T, XPUType>(dev_ctx, x, y, axis, out, f);
}
} // namespace phi
PD_REGISTER_KERNEL(multiply_raw,
XPU,
ALL_LAYOUT,
phi::MultiplyRawKernel,
phi::dtype::float16,
float,
int,
int64_t) {}
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/kernels/legacy/elementwise_multipy_kernel.h"
namespace phi { namespace phi {
namespace sr { namespace sr {
......
...@@ -25,10 +25,9 @@ ...@@ -25,10 +25,9 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void DivideRawKernel(const Context& dev_ctx, void DivideKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
int axis,
DenseTensor* out) { DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx, auto f = [](xpu::Context* ctx,
...@@ -40,14 +39,10 @@ void DivideRawKernel(const Context& dev_ctx, ...@@ -40,14 +39,10 @@ void DivideRawKernel(const Context& dev_ctx,
return xpu::broadcast_div<XPUType>(ctx, x, y, z, xshape, yshape); return xpu::broadcast_div<XPUType>(ctx, x, y, z, xshape, yshape);
}; };
XPUElementwise<T, XPUType>(dev_ctx, x, y, axis, out, f); XPUElementwise<T, XPUType>(dev_ctx, x, y, -1, out, f);
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(divide_raw, PD_REGISTER_KERNEL(
XPU, divide, XPU, ALL_LAYOUT, phi::DivideKernel, phi::dtype::float16, float) {}
ALL_LAYOUT,
phi::DivideRawKernel,
phi::dtype::float16,
float) {}
...@@ -25,10 +25,9 @@ ...@@ -25,10 +25,9 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void MultiplyRawKernel(const Context& dev_ctx, void MultiplyKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
int axis,
DenseTensor* out) { DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx, auto f = [](xpu::Context* ctx,
...@@ -40,15 +39,15 @@ void MultiplyRawKernel(const Context& dev_ctx, ...@@ -40,15 +39,15 @@ void MultiplyRawKernel(const Context& dev_ctx,
return xpu::broadcast_mul<XPUType>(ctx, x, y, z, xshape, yshape); return xpu::broadcast_mul<XPUType>(ctx, x, y, z, xshape, yshape);
}; };
XPUElementwise<T, XPUType>(dev_ctx, x, y, axis, out, f); XPUElementwise<T, XPUType>(dev_ctx, x, y, -1, out, f);
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(multiply_raw, PD_REGISTER_KERNEL(multiply,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::MultiplyRawKernel, phi::MultiplyKernel,
phi::dtype::float16, phi::dtype::float16,
float, float,
int, int,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册