Unverified commit bb801960, authored by Xiaoxu Chen, committed by GitHub

[phi] migrate fmax, fmin kernels to phi (#40140)

Parent 227fa408
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -90,86 +87,6 @@ struct MinFunctor {
template <typename T>
using Complex = paddle::platform::complex<T>;
// Fmax
template <typename T>
struct FMaxFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
return std::fmax(a, b);
}
};
template <>
struct FMaxFunctor<paddle::platform::float16> {
inline HOSTDEVICE paddle::platform::float16 operator()(
const paddle::platform::float16 a,
const paddle::platform::float16 b) const {
float float_a = static_cast<float>(a);
float float_b = static_cast<float>(b);
auto result = std::fmax(float_a, float_b);
return static_cast<paddle::platform::float16>(result);
}
};
template <>
struct FMaxFunctor<int> {
inline HOSTDEVICE int operator()(const int a, const int b) const {
float float_a = static_cast<float>(a);
float float_b = static_cast<float>(b);
auto result = std::fmax(float_a, float_b);
return std::lrint(result);
}
};
template <>
struct FMaxFunctor<int64_t> {
inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const {
double double_a = static_cast<double>(a);
double double_b = static_cast<double>(b);
auto result = std::fmax(double_a, double_b);
return std::llrint(result);
}
};
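A quick host-side sketch (not part of the patch) of why these functors wrap std::fmax rather than std::max: fmax treats a single NaN operand as missing and returns the other value, which is the semantics paddle.fmax exposes.

// Hedged sketch: the NaN handling of std::fmax that the functors above rely on.
#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const double nan = std::numeric_limits<double>::quiet_NaN();
  assert(std::fmax(nan, 2.0) == 2.0);       // a single NaN operand is ignored
  assert(std::fmax(2.0, nan) == 2.0);
  assert(std::isnan(std::fmax(nan, nan)));  // only an all-NaN input yields NaN
  return 0;
}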
// Fmin
template <typename T>
struct FMinFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
return std::fmin(a, b);
}
};
template <>
struct FMinFunctor<paddle::platform::float16> {
inline HOSTDEVICE paddle::platform::float16 operator()(
const paddle::platform::float16 a,
const paddle::platform::float16 b) const {
float float_a = static_cast<float>(a);
float float_b = static_cast<float>(b);
auto result = std::fmin(float_a, float_b);
return static_cast<paddle::platform::float16>(result);
}
};
template <>
struct FMinFunctor<int> {
inline HOSTDEVICE int operator()(const int a, const int b) const {
float float_a = static_cast<float>(a);
float float_b = static_cast<float>(b);
auto result = std::fmin(float_a, float_b);
return std::lrint(result);
}
};
template <>
struct FMinFunctor<int64_t> {
inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const {
double double_a = static_cast<double>(a);
double double_b = static_cast<double>(b);
auto result = std::fmin(double_a, double_b);
return std::llrint(result);
}
};
template <typename T>
struct MinGradXFunctor {
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
......
@@ -151,21 +151,3 @@ REGISTER_OPERATOR(elementwise_fmax, ops::ElementwiseOp,
ops::ElementwiseFMaxGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_fmax_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_fmax,
ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
elementwise_fmax_grad,
ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
@@ -86,21 +86,3 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_fmax,
ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_fmax_grad,
ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
@@ -35,21 +35,6 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class ElementwiseFMaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<FMaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
FMaxFunctor<T>(), z);
}
};
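A minimal sketch of what ElementwiseComputeEx does here, assuming equal shapes (the real helper also handles broadcasting controlled by the `axis` attribute):

// Hedged sketch only: equal-shape elementwise application of a binary functor.
#include <cmath>
#include <cstdio>

template <typename T, typename Functor>
void NaiveElementwise(const T* x, const T* y, T* z, int n, Functor func) {
  for (int i = 0; i < n; ++i) z[i] = func(x[i], y[i]);
}

int main() {
  const float x[3] = {1.f, 4.f, NAN};
  const float y[3] = {2.f, 3.f, 5.f};
  float z[3];
  NaiveElementwise(x, y, z, 3, [](float a, float b) { return std::fmax(a, b); });
  std::printf("%g %g %g\n", z[0], z[1], z[2]);  // prints: 2 4 5
  return 0;
}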
template <typename T>
struct MaxGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
@@ -104,88 +89,5 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
}
};
template <typename T>
struct FMaxGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>((x >= y) || isnan(y));
}
};
template <>
struct FMaxGradDx<paddle::platform::float16> {
HOSTDEVICE paddle::platform::float16 operator()(
paddle::platform::float16 x, paddle::platform::float16 y,
paddle::platform::float16 out, paddle::platform::float16 dout) const {
return dout * static_cast<paddle::platform::float16>(
(x >= y) || paddle::platform::isnan(y));
}
};
template <>
struct FMaxGradDx<int> {
HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
return dout * static_cast<int>((x >= y));
}
};
template <>
struct FMaxGradDx<int64_t> {
HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
int64_t dout) const {
return dout * static_cast<int64_t>((x >= y));
}
};
template <typename T>
struct FMaxGradDy {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(!((x >= y) || isnan(y)));
}
};
template <>
struct FMaxGradDy<paddle::platform::float16> {
HOSTDEVICE paddle::platform::float16 operator()(
paddle::platform::float16 x, paddle::platform::float16 y,
paddle::platform::float16 out, paddle::platform::float16 dout) const {
return dout * static_cast<paddle::platform::float16>(
!((x >= y) || paddle::platform::isnan(y)));
}
};
template <>
struct FMaxGradDy<int64_t> {
HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
int64_t dout) const {
return dout * static_cast<int64_t>(!((x >= y)));
}
};
template <>
struct FMaxGradDy<int> {
HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
return dout * static_cast<int>(!((x >= y)));
}
};
template <typename DeviceContext, typename T>
class ElementwiseFMaxGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* out = dout; // Fake out, not used
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, FMaxGradDx<T>, FMaxGradDy<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, FMaxGradDx<T>(),
FMaxGradDy<T>());
}
};
} // namespace operators
} // namespace paddle
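A small standalone check (a sketch, not the operator's test suite) of the subgradient rule encoded by FMaxGradDx/FMaxGradDy above: the incoming gradient is routed to whichever operand fmax selected, and the two masks are exact complements, so dx + dy == dout elementwise.

// Hedged sketch: gradient routing for fmax, mirroring FMaxGradDx/FMaxGradDy.
#include <cassert>
#include <cmath>

int main() {
  const double x = 1.0, y = 3.0, dout = 0.5;
  const bool x_selected = (x >= y) || std::isnan(y);
  const double dx = dout * static_cast<double>(x_selected);
  const double dy = dout * static_cast<double>(!x_selected);
  assert(dx == 0.0 && dy == 0.5);  // fmax picked y, so y receives the gradient
  assert(dx + dy == dout);         // the two masks partition the gradient
  return 0;
}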
@@ -147,21 +147,3 @@ REGISTER_OPERATOR(elementwise_fmin, ops::ElementwiseOp,
ops::ElementwiseFMinGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_fmin_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_fmin,
ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
elementwise_fmin_grad,
ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
@@ -82,21 +82,3 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_fmin,
ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_fmin_grad,
ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
@@ -35,21 +35,6 @@ class ElementwiseMinKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class ElementwiseFMinKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<FMinFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
FMinFunctor<T>(), z);
}
};
template <typename T>
struct MinGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
@@ -124,89 +109,5 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
ElementwiseMinGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}
};
template <typename T>
struct FMinGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>((x <= y) || isnan(y));
}
};
template <>
struct FMinGradDx<paddle::platform::float16> {
HOSTDEVICE paddle::platform::float16 operator()(
paddle::platform::float16 x, paddle::platform::float16 y,
paddle::platform::float16 out, paddle::platform::float16 dout) const {
return dout * static_cast<paddle::platform::float16>(
(x <= y) || paddle::platform::isnan(y));
}
};
template <>
struct FMinGradDx<int> {
HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
return dout * static_cast<int>((x <= y));
}
};
template <>
struct FMinGradDx<int64_t> {
HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
int64_t dout) const {
return dout * static_cast<int64_t>((x <= y));
}
};
template <typename T>
struct FMinGradDy {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(!((x <= y) || isnan(y)));
}
};
template <>
struct FMinGradDy<paddle::platform::float16> {
HOSTDEVICE paddle::platform::float16 operator()(
paddle::platform::float16 x, paddle::platform::float16 y,
paddle::platform::float16 out, paddle::platform::float16 dout) const {
return dout * static_cast<paddle::platform::float16>(
!((x <= y) || paddle::platform::isnan(y)));
}
};
template <>
struct FMinGradDy<int> {
HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
return dout * static_cast<int>(!((x <= y)));
}
};
template <>
struct FMinGradDy<int64_t> {
HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
int64_t dout) const {
return dout * static_cast<int64_t>(!((x <= y)));
}
};
template <typename DeviceContext, typename T>
class ElementwiseFMinGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* out = dout; // Fake out, not used
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, FMinGradDx<T>, FMinGradDy<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, FMinGradDx<T>(),
FMinGradDy<T>());
}
};
} // namespace operators
} // namespace paddle
@@ -259,3 +259,20 @@ PD_REGISTER_KERNEL(multiply_triple_grad,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(elementwise_fmax_grad,
CPU,
ALL_LAYOUT,
phi::ElementwiseFMaxGradKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_fmin_grad,
CPU,
ALL_LAYOUT,
phi::ElementwiseFMinGradKernel,
float,
double,
int,
int64_t) {}
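A hedged analogy (simplified, hypothetical names) for what PD_REGISTER_KERNEL provides: a runtime lookup from kernel name, backend, and dtype to a concrete function; the real registry lives behind phi's kernel factory rather than a flat map.

// Hedged analogy only; not how phi stores kernels internally.
#include <cstdio>
#include <map>
#include <string>

using KernelFn = void (*)();
void FMaxGradFloatCPU() { std::puts("elementwise_fmax_grad<float> on CPU"); }

int main() {
  std::map<std::string, KernelFn> registry;  // hypothetical flat registry
  registry["elementwise_fmax_grad/CPU/float"] = &FMaxGradFloatCPU;
  registry.at("elementwise_fmax_grad/CPU/float")();  // dispatched at runtime
  return 0;
}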
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
PD_REGISTER_KERNEL(elementwise_fmax,
CPU,
ALL_LAYOUT,
phi::ElementwiseFMaxKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_fmin,
CPU,
ALL_LAYOUT,
phi::ElementwiseFMinKernel,
float,
double,
int,
int64_t) {}
@@ -124,4 +124,22 @@ void MultiplyTripleGradKernel(const Context& dev_ctx,
DenseTensor* d_ddx,
DenseTensor* d_ddy);
template <typename T, typename Context>
void ElementwiseFMaxGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad,
DenseTensor* y_grad);
template <typename T, typename Context>
void ElementwiseFMinGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void ElementwiseFMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void ElementwiseFMinKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
} // namespace phi
@@ -159,6 +159,219 @@ struct DivGradYFunctor<ComplexType<T>> {
return -a * out_div_c_conj;
}
};
// Fmin
template <typename T>
struct FMinFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
return std::fmin(a, b);
}
};
template <>
struct FMinFunctor<dtype::float16> {
inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
const dtype::float16 b) const {
float float_a = static_cast<float>(a);
float float_b = static_cast<float>(b);
auto result = std::fmin(float_a, float_b);
return static_cast<dtype::float16>(result);
}
};
template <>
struct FMinFunctor<int> {
inline HOSTDEVICE int operator()(const int a, const int b) const {
float float_a = static_cast<float>(a);
float float_b = static_cast<float>(b);
auto result = std::fmin(float_a, float_b);
return std::lrint(result);
}
};
template <>
struct FMinFunctor<int64_t> {
inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const {
double double_a = static_cast<double>(a);
double double_b = static_cast<double>(b);
auto result = std::fmin(double_a, double_b);
return std::llrint(result);
}
};
// Fmax
template <typename T>
struct FMaxFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
return std::fmax(a, b);
}
};
template <>
struct FMaxFunctor<dtype::float16> {
inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
const dtype::float16 b) const {
float float_a = static_cast<float>(a);
float float_b = static_cast<float>(b);
auto result = std::fmax(float_a, float_b);
return static_cast<dtype::float16>(result);
}
};
template <>
struct FMaxFunctor<int> {
inline HOSTDEVICE int operator()(const int a, const int b) const {
float float_a = static_cast<float>(a);
float float_b = static_cast<float>(b);
auto result = std::fmax(float_a, float_b);
return std::lrint(result);
}
};
template <>
struct FMaxFunctor<int64_t> {
inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const {
double double_a = static_cast<double>(a);
double double_b = static_cast<double>(b);
auto result = std::fmax(double_a, double_b);
return std::llrint(result);
}
};
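A small sketch of the integer specializations above: operands are widened to a floating type, compared with fmax, and rounded back with lrint/llrint. One observation (not part of the patch): the int64_t path round-trips through double, which represents integers exactly only up to 2^53.

// Hedged sketch: the widen-compare-round pattern used by FMaxFunctor<int>.
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const int a = -3, b = 7;
  const float fa = static_cast<float>(a), fb = static_cast<float>(b);
  assert(std::lrint(std::fmax(fa, fb)) == 7);

  const int64_t big = int64_t{1} << 40;  // well below 2^53, exact in double
  assert(std::llrint(std::fmax(static_cast<double>(big), 0.0)) == big);
  return 0;
}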
template <typename T>
struct FMaxGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>((x >= y) || isnan(y));
}
};
template <>
struct FMaxGradDx<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
dtype::float16 y,
dtype::float16 out,
dtype::float16 dout) const {
return dout * static_cast<dtype::float16>((x >= y) || dtype::isnan(y));
}
};
template <>
struct FMaxGradDx<int> {
HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
return dout * static_cast<int>((x >= y));
}
};
template <>
struct FMaxGradDx<int64_t> {
HOSTDEVICE int64_t operator()(int64_t x,
int64_t y,
int64_t out,
int64_t dout) const {
return dout * static_cast<int64_t>((x >= y));
}
};
template <typename T>
struct FMaxGradDy {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(!((x >= y) || isnan(y)));
}
};
template <>
struct FMaxGradDy<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
dtype::float16 y,
dtype::float16 out,
dtype::float16 dout) const {
return dout * static_cast<dtype::float16>(!((x >= y) || dtype::isnan(y)));
}
};
template <>
struct FMaxGradDy<int64_t> {
HOSTDEVICE int64_t operator()(int64_t x,
int64_t y,
int64_t out,
int64_t dout) const {
return dout * static_cast<int64_t>(!((x >= y)));
}
};
template <>
struct FMaxGradDy<int> {
HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
return dout * static_cast<int>(!((x >= y)));
}
};
template <typename T>
struct FMinGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>((x <= y) || isnan(y));
}
};
template <>
struct FMinGradDx<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
dtype::float16 y,
dtype::float16 out,
dtype::float16 dout) const {
return dout * static_cast<dtype::float16>((x <= y) || dtype::isnan(y));
}
};
template <>
struct FMinGradDx<int> {
HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
return dout * static_cast<int>((x <= y));
}
};
template <>
struct FMinGradDx<int64_t> {
HOSTDEVICE int64_t operator()(int64_t x,
int64_t y,
int64_t out,
int64_t dout) const {
return dout * static_cast<int64_t>((x <= y));
}
};
template <typename T>
struct FMinGradDy {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(!((x <= y) || isnan(y)));
}
};
template <>
struct FMinGradDy<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
dtype::float16 y,
dtype::float16 out,
dtype::float16 dout) const {
return dout * static_cast<dtype::float16>(!((x <= y) || dtype::isnan(y)));
}
};
template <>
struct FMinGradDy<int> {
HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
return dout * static_cast<int>(!((x <= y)));
}
};
template <>
struct FMinGradDy<int64_t> {
HOSTDEVICE int64_t operator()(int64_t x,
int64_t y,
int64_t out,
int64_t dout) const {
return dout * static_cast<int64_t>(!((x <= y)));
}
};
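For integral types there is no NaN, so the specializations above drop the isnan term and the masks reduce to pure comparisons; a tiny check mirroring FMinGradDx<int> and FMinGradDy<int>:

// Hedged sketch: integer fmin gradient masks are plain comparisons.
#include <cassert>

int main() {
  const int x = 2, y = 5, dout = 3;
  assert(dout * static_cast<int>(x <= y) == 3);     // dx: x is the minimum
  assert(dout * static_cast<int>(!(x <= y)) == 0);  // dy receives nothing
  return 0;
}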
template <typename T>
struct MultiplyGradFunctor {
......
@@ -282,3 +282,20 @@ PD_REGISTER_KERNEL(multiply_triple_grad,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(elementwise_fmax_grad,
GPU,
ALL_LAYOUT,
phi::ElementwiseFMaxGradKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_fmin_grad,
GPU,
ALL_LAYOUT,
phi::ElementwiseFMinGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
PD_REGISTER_KERNEL(elementwise_fmax,
GPU,
ALL_LAYOUT,
phi::ElementwiseFMaxKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_fmin,
GPU,
ALL_LAYOUT,
phi::ElementwiseFMinKernel,
float,
double,
int,
int64_t) {}
@@ -258,6 +258,102 @@ void DivideDoubleGradKernel(const Context& dev_ctx,
dout_result.device(place) = static_cast<T>(-1) * dout_result;
}
}
template <typename T, typename Context>
void ElementwiseFMaxGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad,
DenseTensor* y_grad) {
funcs::ElementwiseGradPreProcess(out_grad, x_grad);
auto out = out_grad; // Fake out, not used
auto x_dim = x.dims();
auto y_dim = y.dims();
if (x.dims() == y.dims()) {
funcs::ElemwiseGradComputeNoBroadcast<Context,
T,
funcs::FMaxGradDx<T>,
funcs::FMaxGradDy<T>>(
dev_ctx,
x_dim,
y_dim,
x,
y,
out,
out_grad,
axis,
x_grad,
y_grad,
funcs::FMaxGradDx<T>(),
funcs::FMaxGradDy<T>());
} else {
funcs::ElemwiseGradComputeWithBroadcast<T,
funcs::FMaxGradDx<T>,
funcs::FMaxGradDy<T>>(
dev_ctx,
x_dim,
y_dim,
x,
y,
out,
out_grad,
axis,
x_grad,
y_grad,
funcs::FMaxGradDx<T>(),
funcs::FMaxGradDy<T>());
}
}
template <typename T, typename Context>
void ElementwiseFMinGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad,
DenseTensor* y_grad) {
funcs::ElementwiseGradPreProcess(out_grad, x_grad);
auto out = out_grad; // Fake out, not used
auto x_dim = x.dims();
auto y_dim = y.dims();
if (x.dims() == y.dims()) {
funcs::ElemwiseGradComputeNoBroadcast<Context,
T,
funcs::FMinGradDx<T>,
funcs::FMinGradDy<T>>(
dev_ctx,
x_dim,
y_dim,
x,
y,
out,
out_grad,
axis,
x_grad,
y_grad,
funcs::FMinGradDx<T>(),
funcs::FMinGradDy<T>());
} else {
funcs::ElemwiseGradComputeWithBroadcast<T,
funcs::FMinGradDx<T>,
funcs::FMinGradDy<T>>(
dev_ctx,
x_dim,
y_dim,
x,
y,
out,
out_grad,
axis,
x_grad,
y_grad,
funcs::FMinGradDx<T>(),
funcs::FMinGradDy<T>());
}
}
template <typename T>
struct MulGradDX {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#endif
namespace phi {
template <typename T, typename Context>
void ElementwiseFMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::FMaxFunctor<T>, T, T>(
dev_ctx, x, y, axis, funcs::FMaxFunctor<T>(), out);
}
template <typename T, typename Context>
void ElementwiseFMinKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::FMinFunctor<T>, T, T>(
dev_ctx, x, y, axis, funcs::FMinFunctor<T>(), out);
}
} // namespace phi
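An end-to-end sketch of the forward semantics implemented above, in plain C++ without Paddle types (equal shapes assumed; dev_ctx.template Alloc<T>(out) is the phi-style replacement for the old mutable_data call):

// Hedged sketch: what elementwise_fmax computes for equal-shape inputs.
#include <cmath>
#include <cstdio>

int main() {
  const float x[4] = {1.f, NAN, 3.f, -2.f};
  const float y[4] = {2.f, 5.f, NAN, -7.f};
  float out[4];
  for (int i = 0; i < 4; ++i) out[i] = std::fmax(x[i], y[i]);
  for (int i = 0; i < 4; ++i) std::printf("%g ", out[i]);  // prints: 2 5 3 -2
  std::printf("\n");
  return 0;
}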
@@ -114,6 +114,14 @@ KernelSignature ElementwiseDivGradOpArgumentMapping(
{GradVarName("X"), GradVarName("Y")});
}
KernelSignature ElementwiseFMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("elementwise_fmin_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
}
KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("divide_double_grad",
@@ -130,6 +138,14 @@ KernelSignature ElementwiseMulGradOpArgumentMapping(
{GradVarName("X"), GradVarName("Y")});
}
KernelSignature ElementwiseFMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("elementwise_fmax_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
}
KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("multiply_double_grad",
@@ -192,3 +208,9 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad,
phi::ElementwiseMulDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
phi::ElementwiseMulTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax_grad,
phi::ElementwiseFMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
phi::ElementwiseFMinGradOpArgumentMapping);
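The mapping functions above translate the old fluid op arguments (inputs X and Y, gradient of Out, attribute axis) into the phi kernel signature. Spelled out as plain data, a hedged sketch with a hypothetical struct (GradVarName("Out") expands to "Out@GRAD" in fluid's naming convention):

// Hedged sketch: the argument mapping expressed as plain data.
#include <cstdio>
#include <string>
#include <vector>

struct Signature {
  std::string kernel_name;
  std::vector<std::string> inputs, attrs, outputs;
};

int main() {
  const Signature fmax_grad{"elementwise_fmax_grad",
                            {"X", "Y", "Out@GRAD"},  // GradVarName("Out")
                            {"axis"},
                            {"X@GRAD", "Y@GRAD"}};
  std::printf("%s: %zu inputs, %zu attrs, %zu outputs\n",
              fmax_grad.kernel_name.c_str(), fmax_grad.inputs.size(),
              fmax_grad.attrs.size(), fmax_grad.outputs.size());
  return 0;
}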