未验证 提交 cfadf61b 编写于 作者: F FlyingQianMM 提交者: GitHub

move elementwise_max/min/mod into phi (#40590)

上级 3228fc34
......@@ -54,7 +54,7 @@ USE_OP(sum);
USE_OP_ITSELF(slice_grad);
USE_OP_ITSELF(lookup_table_grad);
USE_OP(sqrt);
USE_OP(elementwise_max);
USE_OP_ITSELF(elementwise_max);
USE_OP_ITSELF(elementwise_div);
USE_OP_ITSELF(sgd);
USE_OP(squared_l2_norm);
......
......@@ -70,75 +70,29 @@ struct InverseFloorDivFunctor {
// Maximum
template <typename T>
struct MaxFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
return a > b ? a : b;
}
};
using MaxFunctor = phi::funcs::MaximumFunctor<T>;
// Minmum
template <typename T>
struct MinFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
return a < b ? a : b;
}
};
using MinFunctor = phi::funcs::MinimumFunctor<T>;
template <typename T>
using Complex = paddle::platform::complex<T>;
// Ternary compare
template <typename T>
struct MinGradXFunctor {
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
return dout * static_cast<T>(x < y);
}
};
using MaxGradXFunctor = phi::funcs::MaxGradXFunctor<T>;
template <typename T>
struct MinGradYFunctor {
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
return dout * static_cast<T>(x >= y);
}
};
using MaxGradYFunctor = phi::funcs::MaxGradYFunctor<T>;
template <typename InT, typename OutT>
struct MinGradXYFunctor {
inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x, const InT y,
const InT dout) {
phi::Array<OutT, 2> outs;
// dx = dout * (x < y)
outs[0] = static_cast<OutT>(dout * static_cast<InT>(x < y));
// dy = dout * (x >= y)
outs[1] = static_cast<OutT>(dout * static_cast<InT>(x >= y));
return outs;
}
};
using MaxGradXYFunctor = phi::funcs::MaxGradXYFunctor<InT, OutT>;
// Ternary compare
template <typename T>
struct MaxGradXFunctor {
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
return dout * static_cast<T>(x > y);
}
};
using MinGradXFunctor = phi::funcs::MinGradXFunctor<T>;
template <typename T>
struct MaxGradYFunctor {
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
return dout * static_cast<T>(x <= y);
}
};
using MinGradYFunctor = phi::funcs::MinGradYFunctor<T>;
template <typename InT, typename OutT>
struct MaxGradXYFunctor {
inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x, const InT y,
const InT dout) {
phi::Array<OutT, 2> outs;
// dx = dout * (x > y)
outs[0] = static_cast<OutT>(dout * static_cast<InT>(x > y));
// dy = dout * (x <= y)
outs[1] = static_cast<OutT>(dout * static_cast<InT>(x <= y));
return outs;
}
};
using MinGradXYFunctor = phi::funcs::MinGradXYFunctor<InT, OutT>;
} // namespace operators
} // namespace paddle
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
......@@ -119,23 +117,6 @@ REGISTER_OPERATOR(elementwise_max, ops::ElementwiseOp,
REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_max,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
elementwise_max_grad,
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_VERSION(elementwise_max)
.AddCheckpoint(
R"ROC(Register elementwise_max for adding the attribute of Scale_y)ROC",
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
namespace paddle {
namespace operators {
template <typename T>
class ElementwiseMaxKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T,
T>(dev_ctx, ins, &outs, axis,
MaxFunctor<T>());
}
};
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
const auto place = ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
std::vector<const framework::Tensor*> ins = {x, y, dout};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dx, dy, MaxGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
std::vector<const framework::Tensor*> ins = {x, y, dout};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dx, MaxGradXFunctor<T>());
} else if (dx == nullptr && dy != nullptr) {
std::vector<const framework::Tensor*> ins = {x, y, dout};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dy, MaxGradYFunctor<T>());
}
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
elementwise_max,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_max_grad,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class ElementwiseMaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
MaxFunctor<T>(), z);
}
};
template <typename T>
struct MaxGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(x > y);
}
};
template <typename T>
struct MaxGradDy {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(x <= y);
}
};
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy);
#endif
template <typename DeviceContext, typename T>
class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* out = dout; // out is not necessary
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
ElementwiseMaxGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
......
......@@ -14,7 +14,6 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
namespace paddle {
......
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
......@@ -119,19 +117,6 @@ REGISTER_OPERATOR(elementwise_min, ops::ElementwiseOp,
REGISTER_OPERATOR(elementwise_min_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_min,
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
elementwise_min_grad,
ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMinGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_VERSION(elementwise_min)
.AddCheckpoint(
R"ROC(Register elementwise_min for adding the attribute of Scale_y)ROC",
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
namespace paddle {
namespace operators {
template <typename T>
class ElementwiseMinKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T,
T>(dev_ctx, ins, &outs, axis,
MinFunctor<T>());
}
};
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseMinGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
const auto place = ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
std::vector<const framework::Tensor*> ins = {x, y, dout};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dx, dy, MinGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
std::vector<const framework::Tensor*> ins = {x, y, dout};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dx, MinGradXFunctor<T>());
} else if (dx == nullptr && dy != nullptr) {
std::vector<const framework::Tensor*> ins = {x, y, dout};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dy, MinGradYFunctor<T>());
}
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
elementwise_min,
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_min_grad,
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class ElementwiseMinKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
MinFunctor<T>(), z);
}
};
template <typename T>
struct MinGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * (x < y);
}
};
template <typename T>
struct MinGradDy {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * (x >= y);
}
};
#ifdef PADDLE_CUDA_FP16
template <>
struct MinGradDx<platform::float16> {
HOSTDEVICE platform::float16 operator()(platform::float16 x,
platform::float16 y,
platform::float16 out,
platform::float16 dout) const {
return x < y ? dout : static_cast<platform::float16>(0);
}
};
template <>
struct MinGradDy<platform::float16> {
HOSTDEVICE platform::float16 operator()(platform::float16 x,
platform::float16 y,
platform::float16 out,
platform::float16 dout) const {
return x >= y ? dout : static_cast<platform::float16>(0);
}
};
#endif
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
ElementwiseMinGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, MinGradDx<T>, MinGradDy<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseMinGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy);
#endif
template <typename DeviceContext, typename T>
class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* out = dout; // Fake out, not used
ElementwiseMinGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}
};
} // namespace operators
} // namespace paddle
......@@ -16,7 +16,6 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
......
......@@ -14,7 +14,6 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
namespace paddle {
......
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
......@@ -62,13 +60,6 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(elementwise_mod, ops::ElementwiseOp,
ops::ElementwiseModOpMaker);
REGISTER_OP_CPU_KERNEL(
elementwise_mod,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseModKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_VERSION(elementwise_mod)
.AddCheckpoint(
R"ROC(Register elementwise_mod for adding the attribute of Scale_y)ROC",
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
namespace paddle {
namespace operators {
template <typename T>
class ElementwiseModKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T,
T>(cuda_ctx, ins, &outs,
axis, ModFunctor<T>());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
elementwise_mod, ops::ElementwiseModKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseModKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseModKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseModKernel<plat::CUDADeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
template <typename T, typename Enable = void>
struct ModFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
T res = a % b;
// Accoding to #PR26732: in dividen % divsor
// remainder shall have the same sign as divsor.
if ((res != 0) && ((b ^ res) < 0)) res += b;
return res;
}
};
template <typename T>
struct ModFunctor<T,
typename std::enable_if_t<std::is_floating_point<T>::value>> {
inline HOSTDEVICE T operator()(const T a, const T b) const {
T res = fmod(a, b);
// Accoding to #PR26732: in dividen % divsor
// remainder shall have the same sign as divsor.
if ((res != 0) && ((res < 0) != (b < 0))) res += b;
return res;
}
};
template <typename T, typename Enable = void>
struct InverseModFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
T res = b % a;
if ((res != 0) && ((res < 0) != (a < 0))) res += a;
return res;
}
};
template <typename T>
struct InverseModFunctor<
T, typename std::enable_if_t<std::is_floating_point<T>::value>> {
inline HOSTDEVICE T operator()(const T a, const T b) const {
T res = fmod(b, a);
if ((res != 0) && ((a < 0) != (res < 0))) res += a;
return res;
}
};
template <typename DeviceContext, typename T>
void elementwise_mod(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<ModFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
ModFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseModFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseModFunctor<T>(), z);
}
}
template <typename DeviceContext, typename T>
class ElementwiseModKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
// dtype of x and y is int64 or int32
elementwise_mod<DeviceContext, T>(ctx, x, y, z);
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
......
......@@ -135,6 +135,32 @@ void MultiplyGradKernel(const Context& dev_ctx,
dev_ctx, x, y, *out, dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
}
template <typename T, typename Context>
void MaximumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy) {
funcs::ElementwiseGradPreProcess(dout, dx);
phi::funcs::ElemwiseGradCompute<Context, T, MaxGradDx<T>, MaxGradDy<T>>(
dev_ctx, x, y, dout, dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
}
template <typename T, typename Context>
void MinimumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy) {
funcs::ElementwiseGradPreProcess(dout, dx);
phi::funcs::ElemwiseGradCompute<Context, T, MinGradDx<T>, MinGradDy<T>>(
dev_ctx, x, y, dout, dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
}
} // namespace phi
PD_REGISTER_KERNEL(add_grad,
......@@ -259,6 +285,7 @@ PD_REGISTER_KERNEL(multiply_triple_grad,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(fmax_grad,
CPU,
ALL_LAYOUT,
......@@ -276,3 +303,23 @@ PD_REGISTER_KERNEL(fmin_grad,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(maximum_grad,
CPU,
ALL_LAYOUT,
phi::MaximumGradKernel,
float,
double,
int,
int64_t,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(minimum_grad,
CPU,
ALL_LAYOUT,
phi::MinimumGradKernel,
float,
double,
int,
int64_t,
phi::dtype::bfloat16) {}
......@@ -70,6 +70,49 @@ void DivideRawKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void MaximumRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
// allocate memory for out
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::MaximumFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::MaximumFunctor<T>(), out);
}
template <typename T, typename Context>
void MinimumRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
// allocate memory for out
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::MinimumFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::MinimumFunctor<T>(), out);
}
template <typename T, typename Context>
void ModuloRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
// allocate memory for out
dev_ctx.template Alloc<T>(out);
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
funcs::ElementwiseCompute<funcs::ModuloFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::ModuloFunctor<T>(), out);
} else {
funcs::ElementwiseCompute<funcs::InverseModuloFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::InverseModuloFunctor<T>(), out);
}
}
// Create the definition of Add
DEFINE_CPU_ELEMENTWISE_OP(Add)
......@@ -138,3 +181,29 @@ PD_REGISTER_KERNEL(multiply_raw,
complex64,
complex128,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(maximum_raw,
CPU,
ALL_LAYOUT,
phi::MaximumRawKernel,
float,
double,
int,
int64_t,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(minimum_raw,
CPU,
ALL_LAYOUT,
phi::MinimumRawKernel,
float,
double,
int,
int64_t,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(modulo_raw,
CPU,
ALL_LAYOUT,
phi::ModuloRawKernel,
float,
double,
int,
int64_t) {}
......@@ -142,4 +142,21 @@ void ElementwiseFMinGradKernel(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* y_grad);
template <typename T, typename Context>
void MaximumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void MinimumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy);
} // namespace phi
......@@ -55,6 +55,32 @@ void MultiplyKernel(const Context& dev_ctx,
MultiplyRawKernel<T>(dev_ctx, x, y, axis, out);
}
template <typename T, typename Context>
void MaximumKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
int axis = -1;
MaximumRawKernel<T>(dev_ctx, x, y, axis, out);
}
template <typename T, typename Context>
void MinimumKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
int axis = -1;
MinimumRawKernel<T>(dev_ctx, x, y, axis, out);
}
template <typename T, typename Context>
void ModuloKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
int axis = -1;
ModuloRawKernel<T>(dev_ctx, x, y, axis, out);
}
} // namespace phi
using complex64 = ::phi::dtype::complex<float>;
......@@ -105,6 +131,26 @@ PD_REGISTER_KERNEL(multiply,
complex64,
complex128,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(maximum,
CPU,
ALL_LAYOUT,
phi::MaximumKernel,
float,
double,
int,
int64_t,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(minimum,
CPU,
ALL_LAYOUT,
phi::MinimumKernel,
float,
double,
int,
int64_t,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(
modulo, CPU, ALL_LAYOUT, phi::ModuloKernel, float, double, int, int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -158,4 +204,26 @@ PD_REGISTER_KERNEL(multiply,
phi::dtype::float16,
complex64,
complex128) {}
PD_REGISTER_KERNEL(maximum,
GPU,
ALL_LAYOUT,
phi::MaximumKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(minimum,
GPU,
ALL_LAYOUT,
phi::MinimumKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(
modulo, GPU, ALL_LAYOUT, phi::ModuloKernel, float, double, int, int64_t) {}
#endif
......@@ -85,6 +85,45 @@ void MultiplyKernel(const Context& dev_ctx,
const DenseTensor& y,
DenseTensor* out);
template <typename T, typename Context>
void MaximumRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void MaximumKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
template <typename T, typename Context>
void MinimumRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void MinimumKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
template <typename T, typename Context>
void ModuloRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void ModuloKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor Add(const Context& dev_ctx,
const DenseTensor& x,
......@@ -129,4 +168,36 @@ DenseTensor Multiply(const Context& dev_ctx,
return dense_out;
}
template <typename T, typename Context>
DenseTensor Maximum(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y) {
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
ElementwiseInferMeta(x, y, &meta_out);
MaximumKernel<T, Context>(dev_ctx, x, y, &dense_out);
return dense_out;
}
template <typename T, typename Context>
DenseTensor Minimum(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y) {
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
ElementwiseInferMeta(x, y, &meta_out);
MinimumKernel<T, Context>(dev_ctx, x, y, &dense_out);
return dense_out;
}
template <typename T, typename Context>
DenseTensor Modulo(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y) {
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
ElementwiseInferMeta(x, y, &meta_out);
ModuloKernel<T, Context>(dev_ctx, x, y, &dense_out);
return dense_out;
}
} // namespace phi
......@@ -422,5 +422,121 @@ struct MultiplyGradXYFunctor<ComplexType<InT>, ComplexType<OutT>> {
}
};
// Maximum
template <typename T>
struct MaximumFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
return a > b ? a : b;
}
};
template <typename T>
struct MaxGradXFunctor {
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
return dout * static_cast<T>(x > y);
}
};
template <typename T>
struct MaxGradYFunctor {
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
return dout * static_cast<T>(x <= y);
}
};
template <typename InT, typename OutT>
struct MaxGradXYFunctor {
inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x,
const InT y,
const InT dout) {
phi::Array<OutT, 2> outs;
// dx = dout * (x > y)
outs[0] = static_cast<OutT>(dout * static_cast<InT>(x > y));
// dy = dout * (x <= y)
outs[1] = static_cast<OutT>(dout * static_cast<InT>(x <= y));
return outs;
}
};
// Minimum
template <typename T>
struct MinimumFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
return a < b ? a : b;
}
};
template <typename T>
struct MinGradXFunctor {
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
return dout * static_cast<T>(x < y);
}
};
template <typename T>
struct MinGradYFunctor {
inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
return dout * static_cast<T>(x >= y);
}
};
template <typename InT, typename OutT>
struct MinGradXYFunctor {
inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x,
const InT y,
const InT dout) {
phi::Array<OutT, 2> outs;
// dx = dout * (x < y)
outs[0] = static_cast<OutT>(dout * static_cast<InT>(x < y));
// dy = dout * (x >= y)
outs[1] = static_cast<OutT>(dout * static_cast<InT>(x >= y));
return outs;
}
};
// Modulo
template <typename T, typename Enable = void>
struct ModuloFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
T res = a % b;
// Accoding to #PR26732: in dividen % divsor
// remainder shall have the same sign as divsor.
if ((res != 0) && ((b ^ res) < 0)) res += b;
return res;
}
};
template <typename T>
struct ModuloFunctor<
T,
typename std::enable_if_t<std::is_floating_point<T>::value>> {
inline HOSTDEVICE T operator()(const T a, const T b) const {
T res = fmod(a, b);
// Accoding to #PR26732: in dividen % divsor
// remainder shall have the same sign as divsor.
if ((res != 0) && ((res < 0) != (b < 0))) res += b;
return res;
}
};
template <typename T, typename Enable = void>
struct InverseModuloFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
T res = b % a;
if ((res != 0) && ((res < 0) != (a < 0))) res += a;
return res;
}
};
template <typename T>
struct InverseModuloFunctor<
T,
typename std::enable_if_t<std::is_floating_point<T>::value>> {
inline HOSTDEVICE T operator()(const T a, const T b) const {
T res = fmod(b, a);
if ((res != 0) && ((a < 0) != (res < 0))) res += a;
return res;
}
};
} // namespace funcs
} // namespace phi
......@@ -148,6 +148,67 @@ void MultiplyGradKernel(const Context& dev_ctx,
ElementwiseMulGrad<T>(dev_ctx, x, y, dout, dx, dy, axis);
}
template <typename T, typename Context>
void MaximumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy) {
const auto place = dev_ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx,
place,
axis,
ins,
dout,
dx,
dy,
funcs::MaxGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXOrYOut<ElementwiseType::kBinary, T>(
dev_ctx, place, axis, ins, dout, dx, funcs::MaxGradXFunctor<T>());
} else if (dy != nullptr && dx == nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dy, funcs::MaxGradYFunctor<T>());
}
}
template <typename T, typename Context>
void MinimumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy) {
const auto place = dev_ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx,
place,
axis,
ins,
dout,
dx,
dy,
funcs::MinGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXOrYOut<ElementwiseType::kBinary, T>(
dev_ctx, place, axis, ins, dout, dx, funcs::MinGradXFunctor<T>());
} else if (dy != nullptr && dx == nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dy, funcs::MinGradYFunctor<T>());
}
}
} // namespace phi
PD_REGISTER_KERNEL(add_grad,
......@@ -299,3 +360,25 @@ PD_REGISTER_KERNEL(fmin_grad,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(maximum_grad,
GPU,
ALL_LAYOUT,
phi::MaximumGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(minimum_grad,
GPU,
ALL_LAYOUT,
phi::MinimumGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -49,6 +49,12 @@ DEFINE_CUDA_ELEMENTWISE_OP(Subtract)
DEFINE_CUDA_ELEMENTWISE_OP(Multiply)
// Create the definition of Divide
DEFINE_CUDA_ELEMENTWISE_OP(Divide)
// Create the definition of Maximum
DEFINE_CUDA_ELEMENTWISE_OP(Maximum)
// Create the definition of Minimum
DEFINE_CUDA_ELEMENTWISE_OP(Minimum)
// Create the definition of Modulo
DEFINE_CUDA_ELEMENTWISE_OP(Modulo)
} // namespace phi
......@@ -114,3 +120,31 @@ PD_REGISTER_KERNEL(multiply_raw,
complex64,
complex128,
bfloat16) {}
PD_REGISTER_KERNEL(maximum_raw,
GPU,
ALL_LAYOUT,
phi::MaximumRawKernel,
float,
double,
int,
int64_t,
float16,
bfloat16) {}
PD_REGISTER_KERNEL(minimum_raw,
GPU,
ALL_LAYOUT,
phi::MinimumRawKernel,
float,
double,
int,
int64_t,
float16,
bfloat16) {}
PD_REGISTER_KERNEL(modulo_raw,
GPU,
ALL_LAYOUT,
phi::ModuloRawKernel,
float,
double,
int,
int64_t) {}
......@@ -628,4 +628,42 @@ void MultiplyTripleGradKernel(const Context& dev_ctx,
}
}
/*
******************************
Maximum Grad
******************************
*/
template <typename T>
struct MaxGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(x > y);
}
};
template <typename T>
struct MaxGradDy {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(x <= y);
}
};
/*
******************************
Minimum Grad
******************************
*/
template <typename T>
struct MinGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(x < y);
}
};
template <typename T>
struct MinGradDy {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(x >= y);
}
};
} // namespace phi
......@@ -55,6 +55,33 @@ KernelSignature ElementwiseDivOpArgumentMapping(
return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature ElementwiseMaxOpArgumentMapping(
const ArgumentMappingContext& ctx) {
int axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (axis == -1) {
return KernelSignature("maximum", {"X", "Y"}, {}, {"Out"});
}
return KernelSignature("maximum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature ElementwiseMinOpArgumentMapping(
const ArgumentMappingContext& ctx) {
int axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (axis == -1) {
return KernelSignature("minimum", {"X", "Y"}, {}, {"Out"});
}
return KernelSignature("minimum_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature ElementwiseModOpArgumentMapping(
const ArgumentMappingContext& ctx) {
int axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (axis == -1) {
return KernelSignature("modulo", {"X", "Y"}, {}, {"Out"});
}
return KernelSignature("modulo_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature ElementwiseAddGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("add_grad",
......@@ -158,12 +185,30 @@ KernelSignature ElementwiseMulTripleGradOpArgumentMapping(
{"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"});
}
KernelSignature ElementwiseMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("maximum_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
}
KernelSignature ElementwiseMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("minimum_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max, maximum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min, minimum);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mod, modulo);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad_grad, add_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_triple_grad, add_triple_grad);
......@@ -178,6 +223,8 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax_grad, fmax_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_max_grad, maximum_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_min_grad, minimum_grad);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
phi::ElementwiseAddOpArgumentMapping);
......@@ -187,6 +234,12 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_mul,
phi::ElementwiseMulOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_div,
phi::ElementwiseDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_max,
phi::ElementwiseMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min,
phi::ElementwiseMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mod,
phi::ElementwiseModOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
phi::ElementwiseAddGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad_grad,
......@@ -211,8 +264,11 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
phi::ElementwiseFMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
phi::ElementwiseFMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax_grad,
phi::ElementwiseFMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
phi::ElementwiseFMinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_max_grad,
phi::ElementwiseMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad,
phi::ElementwiseMinGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册