Unverified commit a3d56a9c, authored by Lijunhui, committed by GitHub

[KP] Complete registry of elementwise ops on XPU with KP (#42056)

Parent ba486c5e
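This commit wires the remaining elementwise ops into the KPS (Kernel Primitives) backend so that maximum, minimum, floor_divide, and elementwise_pow, plus the non-raw entry points of add, subtract, multiply, and divide, are registered for XPU when Paddle is built with PADDLE_WITH_XPU_KP. The hunks below repeat one pattern per op; here is a condensed sketch (names taken from the diff itself, surrounding includes assumed from the existing files):

```cpp
// Per-op pattern this commit repeats (maximum shown; minimum, floor_divide,
// and elementwise_pow follow the same shape). The non-raw kernel simply
// forwards to its *_raw counterpart with the default broadcast axis.
template <typename T, typename Context>
void MaximumKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& y,
                   DenseTensor* out) {
  int axis = -1;  // -1: align trailing dimensions for broadcasting
  MaximumRawKernel<T>(dev_ctx, x, y, axis, out);
}

#ifdef PADDLE_WITH_XPU_KP
// XPU-KP builds register the op for the KPS backend, float only.
PD_REGISTER_KERNEL(maximum, KPS, ALL_LAYOUT, phi::MaximumKernel, float) {}
#endif
```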
@@ -74,11 +74,12 @@ PD_DECLARE_KERNEL(add, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT);
 PD_DECLARE_KERNEL(multiply_grad, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(divide, KPS, ALL_LAYOUT);
-PD_DECLARE_KERNEL(maximum, GPU, ALL_LAYOUT);
 #ifdef PADDLE_WITH_XPU_KP
 PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(maximum, GPU, ALL_LAYOUT);
 #else
 PD_DECLARE_KERNEL(max_raw, KPS, ALL_LAYOUT);
+PD_DECLARE_KERNEL(maximum, KPS, ALL_LAYOUT);
 #endif
 PD_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(mean_grad, GPU, ALL_LAYOUT);
...
@@ -11,6 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
...
@@ -11,6 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
...
@@ -29,7 +29,7 @@ limitations under the License. */
 #include "paddle/phi/api/lib/utils/tensor_utils.h"
 #include "paddle/phi/kernels/cpu/reduce.h"
-#if defined(__HIPCC__) || defined(__NVCC__)
+#if defined(__HIPCC__) || defined(__NVCC__) || defined(__xpu__)
 #include "paddle/phi/kernels/gpu/reduce.h"
 #include "paddle/phi/kernels/gpu/reduce_grad.h"
 #endif
@@ -613,7 +613,7 @@ If reduce_all is true, just reduce along all dimensions and output a scalar.
   virtual std::string GetOpType() const = 0;
 };
 
-#if defined(__HIPCC__) || defined(__NVCC__)
+#if defined(__HIPCC__) || defined(__NVCC__) || defined(__xpu__)
 template <typename T, template <typename> class ReduceOp,
           template <typename, typename> class TransformOp>
 class ReduceCudaKernel : public framework::OpKernel<T> {
@@ -626,9 +626,12 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
     auto pt_out_dtype = paddle::framework::TransToPhiDataType(
         static_cast<framework::proto::VarType::Type>(out_dtype));
     std::vector<int> dims = context.Attr<std::vector<int>>("dim");
+#ifdef PADDLE_WITH_XPU_KP
+    auto& dev_ctx =
+        context.template device_context<paddle::platform::XPUDeviceContext>();
+#else
     auto& dev_ctx = context.cuda_device_context();
+#endif
     if (out_dtype >= 0) {
       output->mutable_data(dev_ctx.GetPlace(), pt_out_dtype);
     } else {
@@ -642,6 +645,7 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
   }
 };
 
+#ifndef PADDLE_WITH_XPU_KP
 template <typename T, template <typename, typename> class TransformOp>
 class ReduceCudaGradKernel : public framework::OpKernel<T> {
  public:
@@ -686,6 +690,7 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
   }
 };
 #endif
+#endif
 
 }  // namespace operators
 }  // namespace paddle
...
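Taken together, the reduce_op.h hunks let the existing CUDA reduce kernel compile for XPU KP: the outer guard gains `__xpu__`, the device context is selected per build, and the gradient kernel is fenced off since only forward reduces take the KP path here. A condensed sketch of the resulting structure (only code visible in the hunks above is assumed; bodies elided):

```cpp
#if defined(__HIPCC__) || defined(__NVCC__) || defined(__xpu__)
template <typename T, template <typename> class ReduceOp,
          template <typename, typename> class TransformOp>
class ReduceCudaKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
#ifdef PADDLE_WITH_XPU_KP
    // XPU-KP builds reuse the same kernel body against the XPU context.
    auto& dev_ctx =
        context.template device_context<paddle::platform::XPUDeviceContext>();
#else
    auto& dev_ctx = context.cuda_device_context();
#endif
    // ... allocate the output on dev_ctx.GetPlace() and reduce over "dim" ...
  }
};

#ifndef PADDLE_WITH_XPU_KP
// Backward stays CUDA/HIP-only; it is compiled out of the XPU-KP build.
template <typename T, template <typename, typename> class TransformOp>
class ReduceCudaGradKernel : public framework::OpKernel<T> { /* ... */ };
#endif
#endif
```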
@@ -42,6 +42,8 @@ XPUOpMap& get_kp_ops() {
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"elementwise_floordiv",
      XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace())})},
+    {"elementwise_pow",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     // activation op
     {"exp", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"hard_swish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
@@ -105,6 +107,8 @@ XPUOpMap& get_kp_ops() {
     {"reduce_prod", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"reduce_all", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace())})},
     {"reduce_any", XPUKernelSet({pOpKernelType(vartype::BOOL, XPUPlace())})},
+    {"reduce_amax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"reduce_amin", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
   };
 
   return s_xpu_kp_kernels;
...
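The entries above extend get_kp_ops(), the allow-list that decides which fluid ops may take the KP path on XPU and with which dtypes. A minimal illustrative lookup (the helper name is hypothetical, not Paddle's dispatcher; XPUOpMap is a map from op name to the XPUKernelSet seen above):

```cpp
// Hypothetical helper: does an op have any KP kernel registered for XPU?
bool HasKpKernel(const std::string& op_type) {
  auto& kp_ops = get_kp_ops();
  return kp_ops.find(op_type) != kp_ops.end();
}
```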
...@@ -103,7 +103,7 @@ PD_REGISTER_KERNEL(elementwise_pow, ...@@ -103,7 +103,7 @@ PD_REGISTER_KERNEL(elementwise_pow,
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(maximum, PD_REGISTER_KERNEL(maximum,
GPU, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::MaximumKernel, phi::MaximumKernel,
float, float,
...@@ -113,7 +113,7 @@ PD_REGISTER_KERNEL(maximum, ...@@ -113,7 +113,7 @@ PD_REGISTER_KERNEL(maximum,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16) {} phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(minimum, PD_REGISTER_KERNEL(minimum,
GPU, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::MinimumKernel, phi::MinimumKernel,
float, float,
...@@ -125,9 +125,9 @@ PD_REGISTER_KERNEL(minimum, ...@@ -125,9 +125,9 @@ PD_REGISTER_KERNEL(minimum,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
modulo, GPU, ALL_LAYOUT, phi::ModuloKernel, float, double, int, int64_t) {} modulo, GPU, ALL_LAYOUT, phi::ModuloKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
floor_divide, GPU, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {} floor_divide, KPS, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {}
PD_REGISTER_KERNEL(elementwise_pow, PD_REGISTER_KERNEL(elementwise_pow,
GPU, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::ElementwisePowKernel, phi::ElementwisePowKernel,
float, float,
......
@@ -18,6 +18,10 @@ limitations under the License. */
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/hostdevice.h"
+#if defined(__xpu__)
+#include <xpu/runtime.h>
+#include "xpu/kernel/math_xpu2.h"  // pow()
+#endif
 
 namespace phi {
 namespace funcs {
@@ -573,6 +577,9 @@ struct ElementwisePowFunctor {
       return std::llrint(
          std::pow(static_cast<double>(a), static_cast<double>(b)));
     }
+#endif
+#ifdef PADDLE_WITH_XPU_KP
+    return pow(a, b);
 #endif
     return std::pow(a, b);
   }
...
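Pieced together from the fragments above, ElementwisePowFunctor now special-cases XPU KP, calling the pow() provided by xpu/kernel/math_xpu2.h. A sketch of the full functor as it should read after this change (the integral-type branch and its device guard are assumed from the surrounding file):

```cpp
template <typename T>
struct ElementwisePowFunctor {
  inline HOSTDEVICE T operator()(const T a, const T b) const {
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
    // Device pow on integral inputs can land just under the true value
    // (e.g. 2.99...), so round to the nearest integer instead of truncating.
    if (std::is_integral<T>::value) {
      return std::llrint(
          std::pow(static_cast<double>(a), static_cast<double>(b)));
    }
#endif
#ifdef PADDLE_WITH_XPU_KP
    return pow(a, b);  // pow() from "xpu/kernel/math_xpu2.h"
#endif
    return std::pow(a, b);
  }
};
```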
@@ -36,6 +36,7 @@ void AddKernel(const Context& dev_ctx,
 }  // namespace phi
 #ifdef PADDLE_WITH_XPU_KP
+PD_REGISTER_KERNEL(add, KPS, ALL_LAYOUT, phi::AddKernel, float) {}
 PD_REGISTER_KERNEL(add_raw, KPS, ALL_LAYOUT, phi::AddRawKernel, float) {}
 #else
...
@@ -37,6 +37,7 @@ void DivideKernel(const Context& dev_ctx,
 }  // namespace phi
 #ifdef PADDLE_WITH_XPU_KP
+PD_REGISTER_KERNEL(divide, KPS, ALL_LAYOUT, phi::DivideKernel, float) {}
 PD_REGISTER_KERNEL(divide_raw, KPS, ALL_LAYOUT, phi::DivideRawKernel, float) {}
 #else
...
@@ -24,24 +24,65 @@ namespace phi {
 // Create the definition of Maximum
 DEFINE_CUDA_ELEMENTWISE_OP(Maximum)
+template <typename T, typename Context>
+void MaximumKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& y,
+                   DenseTensor* out) {
+  int axis = -1;
+  MaximumRawKernel<T>(dev_ctx, x, y, axis, out);
+}
 // Create the definition of Minimum
 DEFINE_CUDA_ELEMENTWISE_OP(Minimum)
+template <typename T, typename Context>
+void MinimumKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& y,
+                   DenseTensor* out) {
+  int axis = -1;
+  MinimumRawKernel<T>(dev_ctx, x, y, axis, out);
+}
 // Create the definition of Modulo
 DEFINE_CUDA_ELEMENTWISE_OP(Modulo)
 // Create the definition of FloorDivide
 DEFINE_CUDA_ELEMENTWISE_OP(FloorDivide)
+template <typename T, typename Context>
+void FloorDivideKernel(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       DenseTensor* out) {
+  int axis = -1;
+  FloorDivideRawKernel<T>(dev_ctx, x, y, axis, out);
+}
 // Create the definition of Pow
 DEFINE_CUDA_ELEMENTWISE_OP(ElementwisePow)
+template <typename T, typename Context>
+void ElementwisePowKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& y,
+                          DenseTensor* out) {
+  int axis = -1;
+  ElementwisePowRawKernel<T>(dev_ctx, x, y, axis, out);
+}
 
 }  // namespace phi
 
 #ifdef PADDLE_WITH_XPU_KP
+PD_REGISTER_KERNEL(maximum, KPS, ALL_LAYOUT, phi::MaximumKernel, float) {}
 PD_REGISTER_KERNEL(maximum_raw, KPS, ALL_LAYOUT, phi::MaximumRawKernel, float) {
 }
+PD_REGISTER_KERNEL(minimum, KPS, ALL_LAYOUT, phi::MinimumKernel, float) {}
 PD_REGISTER_KERNEL(minimum_raw, KPS, ALL_LAYOUT, phi::MinimumRawKernel, float) {
 }
+PD_REGISTER_KERNEL(floor_divide, KPS, ALL_LAYOUT, phi::FloorDivideKernel, int) {
+}
 PD_REGISTER_KERNEL(
     floor_divide_raw, KPS, ALL_LAYOUT, phi::FloorDivideRawKernel, int) {}
+PD_REGISTER_KERNEL(
+    elementwise_pow, KPS, ALL_LAYOUT, phi::ElementwisePowKernel, float) {}
+PD_REGISTER_KERNEL(
+    elementwise_pow_raw, KPS, ALL_LAYOUT, phi::ElementwisePowRawKernel, float) {
+}
 #else
 
 using float16 = phi::dtype::float16;
...
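With both the raw and non-raw variants registered, the elementwise set resolves on XPU through the KPS backend. As a rough usage sketch (assuming a phi::XPUContext dev_ctx and float DenseTensors x, y, out; real code goes through the kernel registry rather than calling the template directly):

```cpp
// Illustrative direct call mirroring what the dispatcher resolves to
// via the ("maximum", KPS) registration above.
phi::MaximumKernel<float, phi::XPUContext>(dev_ctx, x, y, &out);
```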
@@ -37,6 +37,7 @@ void MultiplyKernel(const Context& dev_ctx,
 }  // namespace phi
 #ifdef PADDLE_WITH_XPU_KP
+PD_REGISTER_KERNEL(multiply, KPS, ALL_LAYOUT, phi::MultiplyKernel, float) {}
 PD_REGISTER_KERNEL(
     multiply_raw, KPS, ALL_LAYOUT, phi::MultiplyRawKernel, float) {}
 #else
...
@@ -37,6 +37,7 @@ void SubtractKernel(const Context& dev_ctx,
 }  // namespace phi
 #ifdef PADDLE_WITH_XPU_KP
+PD_REGISTER_KERNEL(subtract, KPS, ALL_LAYOUT, phi::SubtractKernel, float) {}
 PD_REGISTER_KERNEL(
     subtract_raw, KPS, ALL_LAYOUT, phi::SubtractRawKernel, float) {}
 #else
...
...@@ -65,9 +65,9 @@ void LogicalNotKernel(const Context& dev_ctx, ...@@ -65,9 +65,9 @@ void LogicalNotKernel(const Context& dev_ctx,
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(logical_and, KPS, ALL_LAYOUT, phi::LogicalAndKernel, int) {} PD_REGISTER_KERNEL(logical_and, KPS, ALL_LAYOUT, phi::LogicalAndKernel, int) {}
PD_REGISTER_KERNEL(logical_Or, KPS, ALL_LAYOUT, phi::LogicalOrKernel, int) {} PD_REGISTER_KERNEL(logical_or, KPS, ALL_LAYOUT, phi::LogicalOrKernel, int) {}
PD_REGISTER_KERNEL(logical_Not, KPS, ALL_LAYOUT, phi::LogicalNotKernel, int) {} PD_REGISTER_KERNEL(logical_not, KPS, ALL_LAYOUT, phi::LogicalNotKernel, int) {}
PD_REGISTER_KERNEL(logical_Xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {} PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {}
#else #else
#define REGISTER_LOGICAL_CUDA_KERNEL(logical_and, func_type) \ #define REGISTER_LOGICAL_CUDA_KERNEL(logical_and, func_type) \
PD_REGISTER_KERNEL(logical_and, \ PD_REGISTER_KERNEL(logical_and, \
......
...@@ -124,7 +124,8 @@ struct MaxFunctor { ...@@ -124,7 +124,8 @@ struct MaxFunctor {
*/ */
template <typename T> template <typename T>
struct AddFunctor { struct AddFunctor {
inline T initial() { return static_cast<T>(0.0f); } inline T initial() { /*return static_cast<T>(0.0f);*/
}
__device__ T operator()(const T a, const T b) const { return b + a; } __device__ T operator()(const T a, const T b) const { return b + a; }
}; };
...@@ -134,7 +135,8 @@ struct AddFunctor { ...@@ -134,7 +135,8 @@ struct AddFunctor {
*/ */
template <typename T> template <typename T>
struct MulFunctor { struct MulFunctor {
inline T initial() { return static_cast<T>(1.0f); } inline T initial() { /*return static_cast<T>(1.0f);*/
}
__device__ T operator()(const T& a, const T& b) const { return b * a; } __device__ T operator()(const T& a, const T& b) const { return b * a; }
}; };
...@@ -144,7 +146,8 @@ struct MulFunctor { ...@@ -144,7 +146,8 @@ struct MulFunctor {
*/ */
template <typename T> template <typename T>
struct LogicalOrFunctor { struct LogicalOrFunctor {
inline T initial() { return static_cast<T>(false); } inline T initial() { /*return static_cast<T>(false);*/
}
__device__ T operator()(const T& a, const T& b) const { return b || a; } __device__ T operator()(const T& a, const T& b) const { return b || a; }
}; };
......