Unverified commit aed6faf2 authored by shentanyue, committed by GitHub

[Phi] Migrate gelu/log_softmax/prelu op kernel and infershape (#40393)

* add gelu

* fix gelu

* add log_softmax

* add prelu kernel and prelu/gelu/logsoftmax infershape

* fix

* fix

* fix

* fix

* fix ci

* log_softmax rewrite

* fix

* fix

* fix conflict

* fix compile error

* fix comment

* fix

* ci_fix
Co-authored-by: Yan Li <liyan665@gmail.com>
Parent 64a7cbd3
......@@ -32,8 +32,9 @@ USE_OP(conv2d_transpose);
USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP(gelu);
USE_OP_ITSELF(gelu);
USE_OP_DEVICE_KERNEL(gelu, MKLDNN);
PD_DECLARE_ARG_MAPPING_FN(gelu);
namespace paddle {
namespace framework {
......
......@@ -18,6 +18,7 @@
#include <unordered_set>
#include <boost/logic/tribool.hpp>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -27,10 +28,11 @@ USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP_ITSELF(leaky_relu);
USE_OP_DEVICE_KERNEL(leaky_relu, MKLDNN);
USE_OP(gelu);
USE_OP_ITSELF(gelu);
USE_OP_ITSELF(relu);
USE_OP_ITSELF(tanh);
USE_OP_DEVICE_KERNEL(tanh, MKLDNN);
PD_DECLARE_ARG_MAPPING_FN(gelu);
namespace paddle {
namespace framework {
......
......@@ -198,10 +198,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *dst_memory);
}
// elementwise_mul & elementwise_div
else {
} else { // elementwise_mul & elementwise_div
platform::BinaryMKLDNNHandler<T> binary_handler(
BINARY_OP, axis, onednn_engine, ctx.GetPlace(), dout, y, dx, 1.0f,
1.0f, 1.0f);
......@@ -253,10 +250,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
} else {
broadcast_src_memory = reorder_src_memory_p;
}
}
// elementwise_mul & elementwise_div
else {
} else { // elementwise_mul & elementwise_div
std::unordered_map<int, dnnl::memory> args;
std::shared_ptr<dnnl::binary> binary_prim;
std::shared_ptr<dnnl::memory> post_op_memory;
......
......@@ -14,10 +14,11 @@ limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -29,18 +30,6 @@ class GeluOp : public framework::OperatorWithKernel {
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(%s) of GeluOp should not be null.", "X"));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(%s) of GeluOp should not be null.", "Out"));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -156,13 +145,10 @@ class GeluGradOpMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(gelu, GeluInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(gelu, ops::GeluOp, ops::GeluOpMaker,
ops::GeluGradOpMaker<paddle::framework::OpDesc>,
ops::GeluGradOpMaker<paddle::imperative::OpBase>);
ops::GeluGradOpMaker<paddle::imperative::OpBase>,
GeluInferShapeFunctor);
REGISTER_OPERATOR(gelu_grad, ops::GeluGradOp);
REGISTER_OP_CPU_KERNEL(
gelu, ops::GeluKernel<paddle::platform::CPUDeviceContext, float>,
ops::GeluKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
gelu_grad, ops::GeluGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GeluGradKernel<paddle::platform::CPUDeviceContext, double>);
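Note: the fluid REGISTER_OP_CPU_KERNEL registrations removed here are superseded by phi-style registrations added later in this commit (the same pattern is applied to log_softmax and prelu below); a minimal sketch of the replacement, as it appears in the new phi CPU kernel files:
PD_REGISTER_KERNEL(gelu, CPU, ALL_LAYOUT, phi::GeluKernel, float, double) {}
PD_REGISTER_KERNEL(
    gelu_grad, CPU, ALL_LAYOUT, phi::GeluGradKernel, float, double) {}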
......@@ -15,7 +15,9 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
......@@ -30,7 +30,7 @@ limitations under the License. */
namespace f = paddle::framework;
namespace p = paddle::platform;
USE_OP(gelu);
USE_OP_ITSELF(gelu);
USE_OP_DEVICE_KERNEL(gelu, NPU);
template <typename T>
......
......@@ -14,9 +14,9 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
......
......@@ -12,10 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/log_softmax_op.h"
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -24,10 +27,6 @@ class LogSoftmaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
return UnaryOpUnchangedInferShapeCheckAxis(ctx);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -123,18 +122,11 @@ class LogSoftmaxGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(log_softmax, LogSoftmaxInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMetaCheckAxis));
REGISTER_OPERATOR(log_softmax, ops::LogSoftmaxOp, ops::LogSoftmaxOpMaker,
ops::LogSoftmaxOpInferVarType,
ops::LogSoftmaxGradOpMaker<paddle::framework::OpDesc>,
ops::LogSoftmaxGradOpMaker<paddle::imperative::OpBase>);
ops::LogSoftmaxGradOpMaker<paddle::imperative::OpBase>,
LogSoftmaxInferShapeFunctor);
REGISTER_OPERATOR(log_softmax_grad, ops::LogSoftmaxGradOp);
REGISTER_OP_CPU_KERNEL(
log_softmax,
ops::LogSoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogSoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
log_softmax_grad,
ops::LogSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class LogSoftmaxKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<Tensor>("X");
auto *out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
int input_axis = ctx.Attr<int>("axis");
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
phi::SoftmaxForwardCUDAKernelDriver<T, true>(dev_ctx, *x, input_axis, out);
}
};
template <typename T>
class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *out = ctx.Input<Tensor>("Out");
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
int input_axis = ctx.Attr<int>("axis");
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
phi::SoftmaxBackwardCUDAKernelDriver<T, true>(dev_ctx, *out, *dout,
input_axis, dx);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(
log_softmax, ops::LogSoftmaxKernel<plat::CUDADeviceContext, float>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::float16>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(
log_softmax_grad, ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, float>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::bfloat16>);
#else
REGISTER_OP_CUDA_KERNEL(
log_softmax, ops::LogSoftmaxKernel<plat::CUDADeviceContext, float>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, double>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::float16>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(
log_softmax_grad, ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, float>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, double>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::bfloat16>);
#endif
......@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace paddle {
namespace operators {
......@@ -27,7 +28,7 @@ class LogSoftmaxNPUKernel : public framework::OpKernel<T> {
auto* X = ctx.Input<framework::Tensor>("X");
auto* Out = ctx.Output<framework::Tensor>("Out");
const int rank = X->dims().size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
const int axis = phi::funcs::CanonicalAxis(ctx.Attr<int>("axis"), rank);
Out->mutable_data<T>(ctx.GetPlace());
if (X->numel() != 0) {
......@@ -47,7 +48,7 @@ class LogSoftmaxGradNPUKernel : public framework::OpKernel<T> {
auto* dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
const int rank = dOut->dims().size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
const int axis = phi::funcs::CanonicalAxis(ctx.Attr<int>("axis"), rank);
// allocate memory on device.
dX->mutable_data<T>(ctx.GetPlace());
......
......@@ -9,14 +9,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/prelu_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
framework::OpKernelType innerGetKernelTypeForVar(
const Tensor &tensor, const framework::OpKernelType &expected_kernel_type) {
#ifdef PADDLE_WITH_MKLDNN
......@@ -44,95 +49,6 @@ class PReluOp : public framework::OperatorWithKernel {
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "prelu");
OP_INOUT_CHECK(ctx->HasInput("Alpha"), "Input", "Alpha", "prelu");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "prelu");
auto x_dim = ctx->GetInputDim("X");
std::string mode = ctx->Attrs().Get<std::string>("mode");
if (mode == "all") {
PADDLE_ENFORCE_EQ(phi::product(ctx->GetInputDim("Alpha")), 1,
platform::errors::InvalidArgument(
"For mode 'all', size of weight Alpha must be one. "
"But recevied alpha's size: %d.",
product(ctx->GetInputDim("Alpha"))));
} else if (mode == "channel") {
auto x_rank = x_dim.size();
PADDLE_ENFORCE_GE(x_rank, 2,
platform::errors::InvalidArgument(
"For mode 'channel', rank of input X must be "
"equal or larger than 2. But recevied X's "
"rank: %d",
x_rank));
const std::string data_format_str =
ctx->Attrs().Get<std::string>("data_format");
PADDLE_ENFORCE_EQ(data_format_str == "NCHW" || data_format_str == "NHWC",
true,
platform::errors::InvalidArgument(
"For mode 'channel', data_format must be one of "
"NCHW and NHWC. But recevied data_format: %s",
data_format_str));
if (data_format_str == "NCHW" || ctx->IsRunMKLDNNKernel()) {
PADDLE_ENFORCE_EQ(
product(ctx->GetInputDim("Alpha")) == x_dim[1], true,
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[1]: %d",
product(ctx->GetInputDim("Alpha")), x_dim[1]));
} else {
PADDLE_ENFORCE_EQ(
product(ctx->GetInputDim("Alpha")) == x_dim[x_rank - 1], true,
platform::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[%d]: %d",
product(ctx->GetInputDim("Alpha")), x_rank - 1,
x_dim[x_rank - 1]));
}
} else if (mode == "element") {
auto alpha_dim = ctx->GetInputDim("Alpha");
auto alpha_rank = alpha_dim.size();
auto x_rank = x_dim.size();
PADDLE_ENFORCE_GE(x_rank, 1,
platform::errors::InvalidArgument(
"For mode 'element', rank of input X must be "
"equal or larger than 2. But recevied X's "
"rank: %d",
x_rank));
PADDLE_ENFORCE_EQ(
alpha_rank, x_rank,
platform::errors::InvalidArgument(
"For mode 'element', rank of weight Alpha must be ",
"equal to the rank of input(x). But recevied alpha's rank: %d, "
"x's rank: %d.",
alpha_rank, x_rank));
size_t x_product = 1;
size_t alpha_product = 1;
for (int64_t i = x_rank - 1; i > 0; i--) {
x_product *= x_dim[i];
alpha_product *= alpha_dim[i];
}
PADDLE_ENFORCE_EQ(
alpha_product, x_product,
platform::errors::InvalidArgument(
"For mode 'element', the size of weight Alpha must be "
"equal to the size of input(x). But recevied alpha's size: %d, "
"x's size: %d.",
alpha_product, x_product));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Attr(mode) of prelu must be one of 'all', 'channel', or 'element'. "
"But recevied "
"mode: '%s'.",
mode));
}
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -268,13 +184,10 @@ class PReluGradOpMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(prelu, PReluInferShapeFunctor,
PD_INFER_META(phi::PReluInferMeta));
REGISTER_OPERATOR(prelu, ops::PReluOp, ops::PReluOpMaker,
ops::PReluGradOpMaker<paddle::framework::OpDesc>,
ops::PReluGradOpMaker<paddle::imperative::OpBase>);
ops::PReluGradOpMaker<paddle::imperative::OpBase>,
PReluInferShapeFunctor);
REGISTER_OPERATOR(prelu_grad, ops::PReluGradOp);
REGISTER_OP_CPU_KERNEL(
prelu, ops::PReluKernel<paddle::platform::CPUDeviceContext, float>,
ops::PReluKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
prelu_grad, ops::PReluGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::PReluGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/prelu.h"
#include "paddle/fluid/operators/prelu_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
#define CUDA_NUM_THREADS 1024
inline static int PADDLE_GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
template <typename DeviceContext, typename T>
class CUDAPReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* alpha = context.Input<Tensor>("Alpha");
auto* out = context.Output<Tensor>("Out");
const T* x_ptr = x->data<T>();
T* o_ptr = out->mutable_data<T>(context.GetPlace());
const T* alpha_ptr = alpha->data<T>();
auto& mode = context.Attr<std::string>("mode");
auto& data_format = context.Attr<std::string>("data_format");
int numel = x->numel();
auto dim = x->dims();
auto x_rank = dim.size();
VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1] << ", dim["
<< x_rank - 1 << "]:" << dim[x_rank - 1] << ", numel:" << numel;
if (mode == "channel") {
bool channel_last = data_format == "NHWC";
size_t channel = channel_last ? dim[x_rank - 1] : dim[1];
math::PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise;
prelu_channel_wise(context.cuda_device_context().stream(), x_ptr,
alpha_ptr, o_ptr, dim[0], channel, channel_last,
numel);
} else if (mode == "element") {
math::PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise;
prelu_element_wise(context.cuda_device_context().stream(), x_ptr,
alpha_ptr, o_ptr, dim[0], numel);
} else {
math::PreluScalarDirectCUDAFunctor<T> prelu_scalar;
prelu_scalar(context.cuda_device_context().stream(), x_ptr, alpha_ptr,
o_ptr, numel);
}
}
};
enum PRELU_MODE { Element, ChannelFirst, ChannelLast, Scalar };
template <typename T>
__global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
const T* dy_ptr, T* dx_ptr, T* dalpha_ptr,
size_t channel_num, size_t plane_size,
size_t spatial_size, size_t numel,
PRELU_MODE mode) {
CUDA_KERNEL_LOOP(index, numel) {
T scale;
if (mode == Element) {
size_t element_index = index % spatial_size;
scale = alpha_ptr[element_index];
} else if (mode == ChannelFirst) {
size_t temp = index / plane_size;
size_t channel_index = temp % channel_num;
scale = alpha_ptr[channel_index];
} else if (mode == ChannelLast) {
size_t channel_index = index % channel_num;
scale = alpha_ptr[channel_index];
} else {
scale = alpha_ptr[0];
}
T x = x_ptr[index];
T dy = dy_ptr[index];
T zero = static_cast<T>(0);
if (dx_ptr != nullptr) dx_ptr[index] = (x > zero) ? dy : scale * dy;
if (dalpha_ptr != nullptr) dalpha_ptr[index] = (x > zero) ? zero : x * dy;
}
}
template <typename T>
class PreluOpGradFunctor {
public:
void operator()(gpuStream_t stream, const T* x, const T* alpha, const T* dy,
T* dx, T* dalpha, const framework::DDim& input_dims,
PRELU_MODE mode) {
size_t numel = 1;
for (size_t i = 0; i < input_dims.size(); ++i) {
numel *= input_dims[i];
}
size_t plane_size = numel / input_dims[0] / input_dims[1];
size_t spatial_size = numel / input_dims[0];
size_t channel =
mode == ChannelLast ? input_dims[input_dims.size() - 1] : input_dims[1];
PReluOpGradKernel<
T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
x, alpha, dy, dx, dalpha, channel, plane_size, spatial_size, numel,
mode);
}
};
template <typename DeviceContext, typename T>
class CUDAPReluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* alpha = context.Input<Tensor>("Alpha");
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
auto* dy = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dalpha = context.Output<Tensor>(framework::GradVarName("Alpha"));
const T* x_ptr = x->data<T>();
const T* alpha_ptr = alpha->data<T>();
const T* dy_ptr = dy->data<T>();
T* dx_ptr = dx ? dx->mutable_data<T>(context.GetPlace()) : nullptr;
T* dalpha_ptr =
dalpha ? dalpha->mutable_data<T>(context.GetPlace()) : nullptr;
if (!dx && !dalpha) return;
auto& mode = context.Attr<std::string>("mode");
auto& data_format = context.Attr<std::string>("data_format");
int numel = x->numel();
auto dim = x->dims();
auto x_rank = dim.size();
std::vector<int> input_shape = phi::vectorize<int>(dim);
auto stream = context.cuda_device_context().stream();
T* dalpha_tmp_ptr;
Tensor dalpha_tmp;
if (dalpha_ptr == nullptr) {
dalpha_tmp_ptr = dalpha_ptr;
} else {
auto& dev_ctx = context.template device_context<DeviceContext>();
dalpha_tmp = context.AllocateTmpTensor<T, DeviceContext>(dim, dev_ctx);
dalpha_tmp_ptr = dalpha_tmp.mutable_data<T>(context.GetPlace());
}
PRELU_MODE m;
bool channel_last = false;
if (mode == "element") {
m = Element;
} else if (mode == "channel") {
channel_last = data_format == "NHWC";
m = channel_last ? ChannelLast : ChannelFirst;
} else {
m = Scalar;
}
PreluOpGradFunctor<T> prelu_grad;
prelu_grad(stream, x_ptr, alpha_ptr, dy_ptr, dx_ptr, dalpha_tmp_ptr, dim,
m);
if (dalpha_tmp_ptr == nullptr) return;
std::vector<int> reduce_dims;
for (size_t i = 0; i < dim.size(); i++) {
if (mode == "channel" && !channel_last && i == 1) continue;
if (mode == "channel" && channel_last && i == dim.size() - 1) continue;
if (mode == "element" && i != 0) continue;
reduce_dims.push_back(i);
}
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
context.cuda_device_context(), dalpha_tmp, dalpha,
kps::IdentityFunctor<T>(), reduce_dims, stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
prelu, ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, float>,
ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::CUDAPReluKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
prelu_grad,
ops::CUDAPReluGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::CUDAPReluGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::CUDAPReluGradKernel<paddle::platform::CUDADeviceContext, double>);
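For clarity on how the Alpha gradient is reduced above: the element-wise pass writes into dalpha_tmp (same shape as x), and TensorReduceImpl then sums it over every axis that does not index an alpha element. A standalone sketch of that axis selection (hypothetical helper, mirroring the loop in CUDAPReluGradKernel):
#include <string>
#include <vector>

// Returns the axes of x to sum over when reducing dalpha_tmp to dalpha.
std::vector<int> AlphaReduceDims(int rank, const std::string& mode,
                                 bool channel_last) {
  std::vector<int> reduce_dims;
  for (int i = 0; i < rank; ++i) {
    if (mode == "channel" && !channel_last && i == 1) continue;       // keep C
    if (mode == "channel" && channel_last && i == rank - 1) continue; // keep last
    if (mode == "element" && i != 0) continue;                        // keep all but batch
    reduce_dims.push_back(i);
  }
  return reduce_dims;
}
// e.g. rank 4: "channel"/NCHW -> {0, 2, 3}; "channel"/NHWC -> {0, 1, 2};
//              "element" -> {0}; "all" (scalar alpha) -> {0, 1, 2, 3}.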
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using platform::Transform;
template <typename DeviceContext, typename T>
class PReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* alpha = context.Input<Tensor>("Alpha");
auto* out = context.Output<Tensor>("Out");
const T* x_ptr = x->data<T>();
T* o_ptr = out->mutable_data<T>(context.GetPlace());
const T* alpha_ptr = alpha->data<T>();
auto& mode = context.Attr<std::string>("mode");
auto& data_format = context.Attr<std::string>("data_format");
int numel = x->numel();
auto dim = x->dims();
int index = 0;
int i = 0;
if (mode == "channel") {
if (data_format == "NCHW") {
int temp = 1;
for (int j = 2; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) {
index = (i / temp) % dim[1];
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
index = i % dim[dim.size() - 1];
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
}
}
} else if (mode == "element") {
int temp = 1;
for (int j = 1; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) {
index = i % temp;
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[0] * x_ptr[i];
}
}
}
};
template <typename DeviceContext, typename T>
class PReluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dalpha = context.Output<Tensor>(framework::GradVarName("Alpha"));
auto* alpha = context.Input<Tensor>("Alpha");
const T* alpha_ptr = alpha->data<T>();
const T* x_ptr = x->data<T>();
const T* dout_ptr = dout->data<T>();
std::string mode = context.Attr<std::string>("mode");
auto& data_format = context.Attr<std::string>("data_format");
int numel = x->numel();
auto dim = x->dims();
int index = 0;
int i = 0;
if (dx) {
T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
if (mode == "channel") {
if (data_format == "NCHW") {
int temp = 1;
for (int j = 2; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) {
index = (i / temp) % dim[1];
dx_ptr[i] =
x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
index = i % dim[dim.size() - 1];
dx_ptr[i] =
x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
}
}
} else if (mode == "element") {
int temp = 1;
for (int j = 1; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) {
index = i % temp;
dx_ptr[i] =
x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
dx_ptr[i] = x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[0] * dout_ptr[i];
}
}
}
index = 0;
if (dalpha) {
T* dalpha_ptr = dalpha->mutable_data<T>(context.GetPlace());
memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel());
if (mode == "channel") {
if (data_format == "NCHW") {
int temp = 1;
for (int j = 2; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) {
index = (i / temp) % dim[1];
dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
index = i % dim[dim.size() - 1];
dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
}
} else if (mode == "element") {
int temp = 1;
for (int j = 1; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) {
index = i % temp;
dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
dalpha_ptr[0] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
}
}
// TODO(Guanzhong): add GPU kernels
}
};
} // namespace operators
} // namespace paddle
......@@ -918,6 +918,103 @@ void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
out->share_lod(x);
}
void PReluInferMeta(const MetaTensor& x,
const MetaTensor& alpha,
const std::string& mode,
const std::string& data_format,
MetaTensor* out,
MetaConfig config) {
auto x_dim = x.dims();
if (mode == "all") {
PADDLE_ENFORCE_EQ(phi::product(alpha.dims()),
1,
phi::errors::InvalidArgument(
"For mode 'all', size of weight Alpha must be one. "
"But recevied alpha's size: %d.",
product(alpha.dims())));
} else if (mode == "channel") {
auto x_rank = x_dim.size();
PADDLE_ENFORCE_GE(x_rank,
2,
phi::errors::InvalidArgument(
"For mode 'channel', rank of input X must be "
"equal or larger than 2. But recevied X's "
"rank: %d",
x_rank));
PADDLE_ENFORCE_EQ(data_format == "NCHW" || data_format == "NHWC",
true,
phi::errors::InvalidArgument(
"For mode 'channel', data_format must be one of "
"NCHW and NHWC. But recevied data_format: %s",
data_format));
if (data_format == "NCHW" || config.is_run_mkldnn_kernel) {
PADDLE_ENFORCE_EQ(product(alpha.dims()) == x_dim[1],
true,
phi::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[1]: %d",
product(alpha.dims()),
x_dim[1]));
} else {
PADDLE_ENFORCE_EQ(product(alpha.dims()) == x_dim[x_rank - 1],
true,
phi::errors::InvalidArgument(
"For mode 'channel', size of weight Alpha must be "
"equal to the number of channels of input(x). But "
"recevied alpha's size: %d, x_dim[%d]: %d",
product(alpha.dims()),
x_rank - 1,
x_dim[x_rank - 1]));
}
} else if (mode == "element") {
auto alpha_dim = alpha.dims();
auto alpha_rank = alpha_dim.size();
auto x_rank = x_dim.size();
PADDLE_ENFORCE_GE(x_rank,
1,
phi::errors::InvalidArgument(
"For mode 'element', rank of input X must be "
"equal or larger than 2. But recevied X's "
"rank: %d",
x_rank));
PADDLE_ENFORCE_EQ(
alpha_rank,
x_rank,
phi::errors::InvalidArgument(
"For mode 'element', rank of weight Alpha must be ",
"equal to the rank of input(x). But recevied alpha's rank: %d, "
"x's rank: %d.",
alpha_rank,
x_rank));
size_t x_product = 1;
size_t alpha_product = 1;
for (int64_t i = x_rank - 1; i > 0; i--) {
x_product *= x_dim[i];
alpha_product *= alpha_dim[i];
}
PADDLE_ENFORCE_EQ(
alpha_product,
x_product,
phi::errors::InvalidArgument(
"For mode 'element', the size of weight Alpha must be "
"equal to the size of input(x). But recevied alpha's size: %d, "
"x's size: %d.",
alpha_product,
x_product));
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attr(mode) of prelu must be one of 'all', 'channel', or 'element'. "
"But recevied "
"mode: '%s'.",
mode));
}
out->set_dims(x_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);
}
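As a quick reference for the mode-dependent checks above, a small standalone sketch (hypothetical helper, not part of the commit) of the size Alpha is expected to have for a given x shape; for "element" mode it is the per-sample (non-batch) size that must match, and Alpha's rank must also equal x's rank:
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

// Expected Alpha size implied by PReluInferMeta's checks.
int64_t ExpectedAlphaSize(const std::vector<int64_t>& x_dims,
                          const std::string& mode,
                          const std::string& data_format) {
  if (mode == "all") return 1;  // a single shared weight
  if (mode == "channel") {
    return data_format == "NCHW" ? x_dims[1] : x_dims.back();
  }
  // mode == "element": product of all non-batch dimensions of x.
  return std::accumulate(x_dims.begin() + 1, x_dims.end(),
                         static_cast<int64_t>(1), std::multiplies<int64_t>());
}

int main() {
  const std::vector<int64_t> x_dims = {8, 16, 32, 32};  // hypothetical NCHW shape
  assert(ExpectedAlphaSize(x_dims, "all", "NCHW") == 1);
  assert(ExpectedAlphaSize(x_dims, "channel", "NCHW") == 16);
  assert(ExpectedAlphaSize(x_dims, "element", "NCHW") == 16 * 32 * 32);
  return 0;
}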
void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
const MetaTensor& value,
bool out_int32,
......
......@@ -146,6 +146,13 @@ void MatmulInferMeta(const MetaTensor& x,
void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);
void PReluInferMeta(const MetaTensor& x,
const MetaTensor& alpha,
const std::string& mode,
const std::string& data_format,
MetaTensor* out,
MetaConfig config);
void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
const MetaTensor& value,
bool out_int32,
......
......@@ -1650,7 +1650,7 @@ void UnchangedInferMetaCheckAxis(const MetaTensor& x,
PADDLE_ENFORCE_GE(
axis,
-rank,
errors::InvalidArgument(
phi::errors::InvalidArgument(
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X). But received axis: %d, R: %d.",
axis,
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gelu_grad_kernel.h"
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <algorithm>
#include <cmath>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
#define GELU_CONSTANT 0.044715
template <typename T>
struct GeluFunctor {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out, bool approximate) const {
if (approximate) {
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
if (std::is_same<T, platform::float16>::value) {
VLOG(4) << "cast from float16 to float before computing";
auto casted_x = x.template cast<float>();
auto temp =
(static_cast<float>(M_2_SQRTPI * M_SQRT1_2) *
(casted_x + static_cast<float>(GELU_CONSTANT) * casted_x.cube()))
.tanh();
out.device(d) = (casted_x * static_cast<float>(0.5) *
(static_cast<float>(1) + temp))
.template cast<T>();
} else {
auto temp = (static_cast<T>(M_2_SQRTPI * M_SQRT1_2) *
(x + static_cast<T>(GELU_CONSTANT) * x.cube()))
.tanh();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
}
} else {
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP)
auto x_data = x.data();
auto out_data = out.data();
int n = std::min(x.size(), out.size());
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blas_impl.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/gelu_kernel.h"
std::memset(out_data, 0, n * sizeof(T));
phi::funcs::CBlas<T>::AXPY(n, static_cast<T>(M_SQRT1_2), x_data, 1,
out_data, 1);
phi::funcs::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
for (int i = 0; i < n; i++) {
out_data[i] += static_cast<T>(1);
}
phi::funcs::CBlas<T>::VMUL(n, x_data, out_data, out_data);
for (int i = 0; i < n; i++) {
out_data[i] *= static_cast<T>(0.5);
}
#else
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
if (std::is_same<T, platform::float16>::value) {
VLOG(4) << "cast from float16 to float before computing";
auto casted_x = x.template cast<float>();
auto temp = (casted_x * static_cast<float>(M_SQRT1_2)).erf();
out.device(d) = (casted_x * static_cast<float>(0.5) *
(static_cast<float>(1) + temp))
.template cast<T>();
} else {
auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
}
#endif
}
}
};
namespace phi {
template <typename T>
struct GeluGradFunctor {
template <typename Device, typename X, typename dOut, typename dX>
void operator()(Device d, X x, dOut dout, dX dx, bool approximate) const {
if (approximate) {
if (std::is_same<T, platform::float16>::value) {
if (std::is_same<T, dtype::float16>::value) {
VLOG(4) << "cast from float16 to float before computing";
auto casted_x = x.template cast<float>();
auto casted_dout = dout.template cast<float>();
......@@ -138,8 +74,8 @@ struct GeluGradFunctor {
std::memset(second, 0, n * sizeof(T));
// first = (0.5 * (1 + erf(x / sqrt(2))))
phi::funcs::CBlas<T>::AXPY(n, static_cast<T>(M_SQRT1_2), x_data, 1, first,
1);
phi::funcs::CBlas<T>::AXPY(
n, static_cast<T>(M_SQRT1_2), x_data, 1, first, 1);
phi::funcs::CBlas<T>::VMERF(n, first, first, VML_LA);
for (int i = 0; i < n; i++) {
first[i] += static_cast<T>(1);
......@@ -163,7 +99,7 @@ struct GeluGradFunctor {
#else
// gelu_grad(x) = dout * 0.5 * (1 + erf(x / sqrt(2)) + x * sqrt(2 / pi) *
// exp(- x^2 / 2))
if (std::is_same<T, platform::float16>::value) {
if (std::is_same<T, dtype::float16>::value) {
VLOG(4) << "cast from float16 to float before computing";
auto casted_x = x.template cast<float>();
auto casted_dout = dout.template cast<float>();
......@@ -188,46 +124,23 @@ struct GeluGradFunctor {
}
};
template <typename DeviceContext, typename T>
class GeluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
auto approximate = context.Attr<bool>("approximate");
out->mutable_data<T>(in->place());
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
GeluFunctor<T> functor;
functor(place, eigen_in, eigen_out, approximate);
}
};
template <typename DeviceContext, typename T>
class GeluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* dout =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto approximate = context.Attr<bool>("approximate");
dx->mutable_data<T>(dout->place());
auto eigen_x = framework::EigenVector<T>::Flatten(*x);
auto eigen_dout = framework::EigenVector<T>::Flatten(*dout);
auto eigen_dx = framework::EigenVector<T>::Flatten(*dx);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
GeluGradFunctor<T> functor;
functor(place, eigen_x, eigen_dout, eigen_dx, approximate);
}
};
} // namespace operators
} // namespace paddle
template <typename T, typename Context>
void GeluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
bool approximate,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
auto eigen_x = EigenVector<T>::Flatten(x);
auto eigen_out_grad = EigenVector<T>::Flatten(out_grad);
auto eigen_x_grad = EigenVector<T>::Flatten(*x_grad);
auto& dev = *dev_ctx.eigen_device();
GeluGradFunctor<T> functor;
functor(dev, eigen_x, eigen_out_grad, eigen_x_grad, approximate);
}
} // namespace phi
PD_REGISTER_KERNEL(
gelu_grad, CPU, ALL_LAYOUT, phi::GeluGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gelu_kernel.h"
#include <algorithm>
#include <cmath>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blas_impl.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T>
struct GeluFunctor {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out, bool approximate) const {
if (approximate) {
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / \pi) * (x + 0.044715 * x^{3})))
if (std::is_same<T, dtype::float16>::value) {
VLOG(4) << "cast from float16 to float before computing";
auto casted_x = x.template cast<float>();
auto temp =
(static_cast<float>(M_2_SQRTPI * M_SQRT1_2) *
(casted_x + static_cast<float>(GELU_CONSTANT) * casted_x.cube()))
.tanh();
out.device(d) = (casted_x * static_cast<float>(0.5) *
(static_cast<float>(1) + temp))
.template cast<T>();
} else {
auto temp = (static_cast<T>(M_2_SQRTPI * M_SQRT1_2) *
(x + static_cast<T>(GELU_CONSTANT) * x.cube()))
.tanh();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
}
} else {
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP)
auto x_data = x.data();
auto out_data = out.data();
int n = std::min(x.size(), out.size());
std::memset(out_data, 0, n * sizeof(T));
phi::funcs::CBlas<T>::AXPY(
n, static_cast<T>(M_SQRT1_2), x_data, 1, out_data, 1);
phi::funcs::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
for (int i = 0; i < n; i++) {
out_data[i] += static_cast<T>(1);
}
phi::funcs::CBlas<T>::VMUL(n, x_data, out_data, out_data);
for (int i = 0; i < n; i++) {
out_data[i] *= static_cast<T>(0.5);
}
#else
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
if (std::is_same<T, dtype::float16>::value) {
VLOG(4) << "cast from float16 to float before computing";
auto casted_x = x.template cast<float>();
auto temp = (casted_x * static_cast<float>(M_SQRT1_2)).erf();
out.device(d) = (casted_x * static_cast<float>(0.5) *
(static_cast<float>(1) + temp))
.template cast<T>();
} else {
auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
}
#endif
}
}
};
template <typename T, typename Context>
void GeluKernel(const Context& dev_ctx,
const DenseTensor& x,
bool approximate,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
auto eigen_out = EigenVector<T>::Flatten(*out);
auto eigen_x = EigenVector<T>::Flatten(x);
auto& dev = *dev_ctx.eigen_device();
GeluFunctor<T> functor;
functor(dev, eigen_x, eigen_out, approximate);
}
} // namespace phi
PD_REGISTER_KERNEL(gelu, CPU, ALL_LAYOUT, phi::GeluKernel, float, double) {}
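As a sanity check on the two formulas quoted in the comments above, a minimal scalar sketch of GELU (illustrative only, independent of the Eigen and MKL code paths; assumes M_SQRT1_2 / M_2_SQRTPI are available, hence the _USE_MATH_DEFINES guard on Windows):
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES  // for M_SQRT1_2 / M_2_SQRTPI on Windows
#endif
#include <cmath>

inline float GeluExact(float x) {
  // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
  return 0.5f * x * (1.0f + std::erf(x * static_cast<float>(M_SQRT1_2)));
}

inline float GeluTanhApprox(float x) {
  // gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
  const float kSqrt2OverPi = static_cast<float>(M_2_SQRTPI * M_SQRT1_2);
  return 0.5f * x *
         (1.0f + std::tanh(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
}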
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/log_softmax_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrixTemplate = EigenMatrix<T, MajorType, IndexType>;
template <typename Context, typename T>
struct LogSoftmaxGradFunctor {
void operator()(const Context& context,
const DenseTensor* Y,
const DenseTensor* dY,
DenseTensor* dX,
const int axis) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
const int n = funcs::SizeToAxis(axis, Y->dims());
const int d = funcs::SizeFromAxis(axis, Y->dims());
phi::DDim dim_2d{n, d};
auto y = EigenMatrixTemplate<T>::From(*Y, dim_2d);
auto dy = EigenMatrixTemplate<T>::From(*dY, dim_2d);
auto dx = EigenMatrixTemplate<T>::From(*dX, dim_2d);
const int axis_dim = Y->dims()[axis];
const int batch_size = y.dimension(kBatchDim);
const int num_classes = y.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
dx.device(*context.eigen_device()) =
dy -
(y.exp()) * (dy.reshape(batch_axis_remain)
.sum(along_class)
.broadcast(one_axis));
}
};
template <typename T, typename Context>
void LogSoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad) {
const int rank = out.dims().size();
const int canonical_axis = funcs::CanonicalAxis(axis, rank);
dev_ctx.template Alloc<T>(x_grad);
if (out.numel() != 0) {
LogSoftmaxGradFunctor<Context, T>()(
dev_ctx, &out, &out_grad, x_grad, canonical_axis);
}
}
} // namespace phi
PD_REGISTER_KERNEL(log_softmax_grad,
CPU,
ALL_LAYOUT,
phi::LogSoftmaxGradKernel,
float,
double) {}
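Both log_softmax functors flatten the input to an [n, d] matrix around the canonicalized axis via the phi::funcs helpers; a standalone sketch of those helpers' semantics (re-implemented here only for illustration, assuming they behave like the removed fluid versions shown further below):
#include <cstdint>
#include <vector>

inline int CanonicalAxis(int axis, int rank) {
  return axis < 0 ? axis + rank : axis;  // e.g. axis -1 with rank 4 -> 3
}
inline int64_t SizeToAxis(int axis, const std::vector<int64_t>& dims) {
  int64_t size = 1;
  for (int i = 0; i < axis; ++i) size *= dims[i];
  return size;
}
inline int64_t SizeFromAxis(int axis, const std::vector<int64_t>& dims) {
  int64_t size = 1;
  for (int i = axis; i < static_cast<int>(dims.size()); ++i) size *= dims[i];
  return size;
}
// Example: dims = {2, 3, 4, 5}, axis = -2 -> canonical axis 2,
// n = SizeToAxis(2, dims) = 6, d = SizeFromAxis(2, dims) = 20,
// so the kernel views the tensor as a 6 x 20 matrix.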
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T, int MajorType = Eigen::RowMajor,
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/log_softmax_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
static inline int CanonicalAxis(const int axis, const int rank) {
if (axis < 0) {
return axis + rank;
}
return axis;
}
static inline size_t SizeToAxis(const int axis, const framework::DDim dims) {
size_t size = 1;
for (int i = 0; i < axis; i++) {
size *= dims[i];
}
return size;
}
static inline size_t SizeFromAxis(const int axis, const framework::DDim dims) {
size_t size = 1;
for (int i = axis; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
using EigenMatrixTemplate = EigenMatrix<T, MajorType, IndexType>;
template <typename T>
struct ValueClip {
......@@ -53,21 +35,23 @@ struct ValueClip {
}
};
template <typename DeviceContext, typename T>
template <typename Context, typename T>
struct LogSoftmaxFunctor {
void operator()(const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y, const int axis) {
void operator()(const Context& context,
const DenseTensor* X,
DenseTensor* Y,
const int axis) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int kAxisDim = 1;
int axis_dim = X->dims()[axis];
const int n = SizeToAxis(axis, X->dims());
const int d = SizeFromAxis(axis, X->dims());
framework::DDim dim_2d{n, d};
const int n = funcs::SizeToAxis(axis, X->dims());
const int d = funcs::SizeFromAxis(axis, X->dims());
phi::DDim dim_2d{n, d};
auto logits = EigenMatrix<T>::From(*X, dim_2d);
auto log_softmax = EigenMatrix<T>::From(*Y, dim_2d);
auto logits = EigenMatrixTemplate<T>::From(*X, dim_2d);
auto log_softmax = EigenMatrixTemplate<T>::From(*Y, dim_2d);
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
......@@ -119,79 +103,21 @@ struct LogSoftmaxFunctor {
}
};
template <typename DeviceContext, typename T>
class LogSoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Out = context.Output<framework::Tensor>("Out");
const int rank = X->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
if (X->numel() != 0) {
LogSoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), X, Out, axis);
}
}
};
template <typename DeviceContext, typename T>
struct LogSoftmaxGradFunctor {
void operator()(const DeviceContext& context, const framework::Tensor* Y,
const framework::Tensor* dY, framework::Tensor* dX,
const int axis) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
const int n = SizeToAxis(axis, Y->dims());
const int d = SizeFromAxis(axis, Y->dims());
framework::DDim dim_2d{n, d};
auto y = EigenMatrix<T>::From(*Y, dim_2d);
auto dy = EigenMatrix<T>::From(*dY, dim_2d);
auto dx = EigenMatrix<T>::From(*dX, dim_2d);
const int axis_dim = Y->dims()[axis];
const int batch_size = y.dimension(kBatchDim);
const int num_classes = y.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
dx.device(*context.eigen_device()) =
dy -
(y.exp()) * (dy.reshape(batch_axis_remain)
.sum(along_class)
.broadcast(one_axis));
template <typename T, typename Context>
void LogSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out) {
const int rank = x.dims().size();
const int canonical_axis = funcs::CanonicalAxis(axis, rank);
dev_ctx.template Alloc<T>(out);
if (x.numel() != 0) {
LogSoftmaxFunctor<Context, T>()(dev_ctx, &x, out, canonical_axis);
}
};
}
template <typename DeviceContext, typename T>
class LogSoftmaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* Out = context.Input<framework::Tensor>("Out");
auto* dOut =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
const int rank = Out->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
if (Out->numel() != 0) {
LogSoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), Out, dOut, dX,
axis);
}
}
};
} // namespace phi
} // namespace operators
} // namespace paddle
PD_REGISTER_KERNEL(
log_softmax, CPU, ALL_LAYOUT, phi::LogSoftmaxKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/prelu_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void PReluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& alpha,
const DenseTensor& out_grad,
const std::string& mode,
const std::string& data_format,
DenseTensor* x_grad,
DenseTensor* alpha_grad) {
const T* alpha_ptr = alpha.data<T>();
const T* x_ptr = x.data<T>();
const T* out_grad_ptr = out_grad.data<T>();
int numel = x.numel();
auto dim = x.dims();
int index = 0;
int i = 0;
if (x_grad) {
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
if (mode == "channel") {
if (data_format == "NCHW") {
int temp = 1;
for (int j = 2; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) {
index = (i / temp) % dim[1];
x_grad_ptr[i] = x_ptr[i] > 0 ? out_grad_ptr[i]
: alpha_ptr[index] * out_grad_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
index = i % dim[dim.size() - 1];
x_grad_ptr[i] = x_ptr[i] > 0 ? out_grad_ptr[i]
: alpha_ptr[index] * out_grad_ptr[i];
}
}
} else if (mode == "element") {
int temp = 1;
for (int j = 1; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) {
index = i % temp;
x_grad_ptr[i] =
x_ptr[i] > 0 ? out_grad_ptr[i] : alpha_ptr[index] * out_grad_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
x_grad_ptr[i] =
x_ptr[i] > 0 ? out_grad_ptr[i] : alpha_ptr[0] * out_grad_ptr[i];
}
}
}
index = 0;
if (alpha_grad) {
T* alpha_grad_ptr = dev_ctx.template Alloc<T>(alpha_grad);
memset(alpha_grad_ptr, 0, sizeof(T) * alpha_grad->numel());
if (mode == "channel") {
if (data_format == "NCHW") {
int temp = 1;
for (int j = 2; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) {
index = (i / temp) % dim[1];
alpha_grad_ptr[index] +=
x_ptr[i] > 0 ? 0 : x_ptr[i] * out_grad_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
index = i % dim[dim.size() - 1];
alpha_grad_ptr[index] +=
x_ptr[i] > 0 ? 0 : x_ptr[i] * out_grad_ptr[i];
}
}
} else if (mode == "element") {
int temp = 1;
for (int j = 1; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) {
index = i % temp;
alpha_grad_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * out_grad_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
alpha_grad_ptr[0] += x_ptr[i] > 0 ? 0 : x_ptr[i] * out_grad_ptr[i];
}
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
prelu_grad, CPU, ALL_LAYOUT, phi::PReluGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/prelu_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void PReluKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& alpha,
const std::string& mode,
const std::string& data_format,
DenseTensor* out) {
const T* x_ptr = x.data<T>();
const T* alpha_ptr = alpha.data<T>();
T* o_ptr = dev_ctx.template Alloc<T>(out);
int numel = x.numel();
auto dim = x.dims();
int index = 0;
int i = 0;
if (mode == "channel") {
if (data_format == "NCHW") {
int temp = 1;
for (int j = 2; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) {
index = (i / temp) % dim[1];
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
index = i % dim[dim.size() - 1];
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
}
}
} else if (mode == "element") {
int temp = 1;
for (int j = 1; j < dim.size(); j++) {
temp *= dim[j];
}
for (i = 0; i < numel; i++) {
index = i % temp;
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[0] * x_ptr[i];
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(prelu, CPU, ALL_LAYOUT, phi::PReluKernel, float, double) {}
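For reference, the per-element formulas implemented by the two CPU kernels above, written out as a minimal scalar sketch (illustrative only):
// Forward:  out    = x > 0 ? x    : alpha * x
// Backward: dx     = x > 0 ? dout : alpha * dout
//           dalpha = x > 0 ? 0    : x * dout   (accumulated per alpha element)
inline float PReluForward(float x, float alpha) {
  return x > 0.0f ? x : alpha * x;
}
inline float PReluGradX(float x, float alpha, float dout) {
  return x > 0.0f ? dout : alpha * dout;
}
inline float PReluGradAlpha(float x, float dout) {
  return x > 0.0f ? 0.0f : x * dout;
}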
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES // use M_2_SQRTPI on Windows
#endif
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GeluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
bool approximate,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES // use M_2_SQRTPI on Windows
#endif
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
#define GELU_CONSTANT 0.044715
template <typename T, typename Context>
void GeluKernel(const Context& dev_ctx,
const DenseTensor& x,
bool approximate,
DenseTensor* out);
} // namespace phi
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/gelu_op.h"
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/flags.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
DECLARE_bool(use_fast_math);
namespace paddle {
namespace operators {
namespace phi {
#ifdef __NVCC__
template <bool FastMode>
......@@ -52,7 +55,8 @@ static __device__ __forceinline__ float FP32GeluBwd(float x, float y_g) {
}
template <int VecSize, bool FastMode>
static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x, __half* y,
static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x,
__half* y,
size_t n) {
size_t offset =
static_cast<size_t>(threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
......@@ -71,7 +75,8 @@ static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x, __half* y,
template <int VecSize, bool FastMode>
static __global__ void FP16FastGeluBwdCUDAKernel(const __half* x,
const __half* y_g, __half* x_g,
const __half* y_g,
__half* x_g,
size_t n) {
size_t offset =
static_cast<size_t>(threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
......@@ -94,8 +99,7 @@ static __global__ void FP16FastGeluBwdCUDAKernel(const __half* x,
}
static bool TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(
const platform::CUDADeviceContext& dev_ctx, const __half* x, __half* y,
size_t n) {
const GPUContext& dev_ctx, const __half* x, __half* y, size_t n) {
auto is_aligned = [](const void* p, size_t alignment) {
return reinterpret_cast<uintptr_t>(p) % alignment == 0;
};
......@@ -129,8 +133,11 @@ static bool TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(
}
static bool TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(
const platform::CUDADeviceContext& dev_ctx, const __half* x,
const __half* y_g, __half* x_g, size_t n) {
const GPUContext& dev_ctx,
const __half* x,
const __half* y_g,
__half* x_g,
size_t n) {
auto is_aligned = [](const void* p, size_t alignment) {
return reinterpret_cast<uintptr_t>(p) % alignment == 0;
};
......@@ -149,8 +156,8 @@ static bool TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(
<< " , thread = " << thread; \
FP16FastGeluBwdCUDAKernel< \
__vec_size, \
__use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y_g, \
x_g, n); \
__use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>( \
x, y_g, x_g, n); \
return true; \
} \
} while (0)
......@@ -166,155 +173,4 @@ static bool TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(
}
#endif
template <typename T>
struct GeluWithApproximateFunctor {
using MPType = typename details::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T arg_x) {
// this function is tanh approximation of gelu
MPType x = static_cast<MPType>(arg_x);
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
auto tanh_out =
tanh(kAlpha * x * (one + static_cast<MPType>(GELU_CONSTANT) * x * x));
MPType out = x * half * (one + tanh_out);
return static_cast<T>(out);
}
};
template <typename T>
struct GeluWithoutApproximateFunctor {
using MPType = typename details::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T arg_x) {
// actual gelu with approximation = false
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(x * normcdf(x));
}
};
template <typename T>
class GeluKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
auto approximate = context.Attr<bool>("approximate");
out->mutable_data<T>(in->place());
std::vector<const framework::Tensor*> ins = {in};
std::vector<framework::Tensor*> outs = {out};
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
if (approximate) {
#ifdef __NVCC__
if (std::is_same<T, platform::float16>::value) {
size_t n = in->numel();
const auto* in_ptr = reinterpret_cast<const __half*>(in->data<T>());
auto* out_ptr = reinterpret_cast<__half*>(out->data<T>());
if (TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(dev_ctx, in_ptr,
out_ptr, n)) {
return;
}
}
#endif
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary,
T, T>(
dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor<T>());
} else {
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary,
T, T>(
dev_ctx, ins, &outs, 0, GeluWithoutApproximateFunctor<T>());
}
}
};
template <typename T>
struct GeluWithApproximateGradFunctor {
using MPType = typename details::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
MPType x = static_cast<MPType>(arg_x);
MPType dout = static_cast<MPType>(arg_dout);
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
MPType kBeta =
kAlpha * static_cast<MPType>(GELU_CONSTANT) * static_cast<MPType>(3);
auto cube_x = x * x * x;
auto tanh_out =
tanh(kAlpha * ((static_cast<MPType>(GELU_CONSTANT) * cube_x) + x));
auto ans =
half * (one + tanh_out +
(one - tanh_out * tanh_out) * (x * kAlpha + kBeta * cube_x));
return static_cast<T>(ans * dout);
}
};
template <typename T>
struct GeluWithoutApproximateGradFunctor {
using MPType = typename details::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
MPType x = static_cast<MPType>(arg_x);
MPType dout = static_cast<MPType>(arg_dout);
constexpr MPType kBeta = M_2_SQRTPI * M_SQRT1_2 * static_cast<MPType>(0.5);
const MPType cdf = normcdf(x);
const MPType pdf = exp(static_cast<MPType>(-0.5) * x * x) * kBeta;
return static_cast<T>(dout * (cdf + x * pdf));
}
};
template <typename T>
class GeluGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* dout =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto approximate = context.Attr<bool>("approximate");
dx->mutable_data<T>(dout->place());
std::vector<const framework::Tensor*> ins = {x, dout};
std::vector<framework::Tensor*> outs = {dx};
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
if (approximate) {
#ifdef __NVCC__
if (std::is_same<T, platform::float16>::value) {
size_t n = x->numel();
const auto* x_ptr = reinterpret_cast<const __half*>(x->data<T>());
const auto* y_g_ptr = reinterpret_cast<const __half*>(dout->data<T>());
auto* x_g_ptr = reinterpret_cast<__half*>(dx->data<T>());
if (TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(dev_ctx, x_ptr, y_g_ptr,
x_g_ptr, n)) {
return;
}
}
#endif
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary,
T, T>(
dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
} else {
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary,
T, T>(
dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor<T>());
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
gelu, ops::GeluKernel<paddle::platform::CUDADeviceContext, float>,
ops::GeluKernel<paddle::platform::CUDADeviceContext, double>,
ops::GeluKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
gelu_grad, ops::GeluGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::GeluGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::GeluGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gelu_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/gpu/gelu_funcs.h"
DECLARE_bool(use_fast_math);
namespace phi {
template <typename T>
struct GeluWithApproximateGradFunctor {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
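    // Gradient of the tanh-approximated GELU:
    //   d/dx [0.5 * x * (1 + tanh(u))]
    //     = 0.5 * (1 + tanh(u) + x * (1 - tanh(u)^2) * du/dx),
    // with u = kAlpha * (x + GELU_CONSTANT * x^3) and kAlpha = sqrt(2/pi),
    // which is the (x * kAlpha + kBeta * cube_x) term below.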
MPType x = static_cast<MPType>(arg_x);
MPType dout = static_cast<MPType>(arg_dout);
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
MPType kBeta =
kAlpha * static_cast<MPType>(GELU_CONSTANT) * static_cast<MPType>(3);
auto cube_x = x * x * x;
auto tanh_out =
tanh(kAlpha * ((static_cast<MPType>(GELU_CONSTANT) * cube_x) + x));
auto ans =
half * (one + tanh_out +
(one - tanh_out * tanh_out) * (x * kAlpha + kBeta * cube_x));
return static_cast<T>(ans * dout);
}
};
template <typename T>
struct GeluWithoutApproximateGradFunctor {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
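    // Exact GELU gradient: d/dx [x * Phi(x)] = Phi(x) + x * phi(x), where Phi
    // is the standard normal CDF (normcdf) and phi its PDF; kBeta is
    // 1/sqrt(2*pi), the PDF normalization constant.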
MPType x = static_cast<MPType>(arg_x);
MPType dout = static_cast<MPType>(arg_dout);
constexpr MPType kBeta = M_2_SQRTPI * M_SQRT1_2 * static_cast<MPType>(0.5);
const MPType cdf = normcdf(x);
const MPType pdf = exp(static_cast<MPType>(-0.5) * x * x) * kBeta;
return static_cast<T>(dout * (cdf + x * pdf));
}
};
template <typename T, typename Context>
void GeluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
bool approximate,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
std::vector<const DenseTensor*> ins = {&x, &out_grad};
std::vector<DenseTensor*> outs = {x_grad};
if (approximate) {
#ifdef __NVCC__
if (std::is_same<T, dtype::float16>::value) {
size_t n = x.numel();
const auto* x_ptr = reinterpret_cast<const __half*>(x.data<T>());
const auto* y_g_ptr = reinterpret_cast<const __half*>(out_grad.data<T>());
auto* x_g_ptr = reinterpret_cast<__half*>(x_grad->data<T>());
if (TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(
dev_ctx, x_ptr, y_g_ptr, x_g_ptr, n)) {
return;
}
}
#endif
phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
} else {
phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, GeluWithoutApproximateGradFunctor<T>());
}
}
} // namespace phi
PD_REGISTER_KERNEL(gelu_grad,
GPU,
ALL_LAYOUT,
phi::GeluGradKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gelu_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/gpu/gelu_funcs.h"
DECLARE_bool(use_fast_math);
namespace phi {
template <typename T>
struct GeluWithApproximateFunctor {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T arg_x) {
// this function is tanh approximation of gelu
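    //   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + GELU_CONSTANT * x^3)))
    // kAlpha below is sqrt(2/pi), built from M_2_SQRTPI * M_SQRT1_2.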
MPType x = static_cast<MPType>(arg_x);
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
auto tanh_out =
tanh(kAlpha * x * (one + static_cast<MPType>(GELU_CONSTANT) * x * x));
MPType out = x * half * (one + tanh_out);
return static_cast<T>(out);
}
};
template <typename T>
struct GeluWithoutApproximateFunctor {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T arg_x) {
// actual gelu with approximation = false
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(x * normcdf(x));
}
};
template <typename T, typename Context>
void GeluKernel(const Context& dev_ctx,
const DenseTensor& x,
bool approximate,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
if (approximate) {
#ifdef __NVCC__
if (std::is_same<T, dtype::float16>::value) {
size_t n = x.numel();
const auto* in_ptr = reinterpret_cast<const __half*>(x.data<T>());
auto* out_ptr = reinterpret_cast<__half*>(out->data<T>());
if (TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(
dev_ctx, in_ptr, out_ptr, n)) {
return;
}
}
#endif
phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor<T>());
} else {
phi::funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, 0, GeluWithoutApproximateFunctor<T>());
}
}
} // namespace phi
PD_REGISTER_KERNEL(gelu,
GPU,
ALL_LAYOUT,
phi::GeluKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/log_softmax_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace phi {
template <typename T, typename Context>
void LogSoftmaxGradKernel(const Context &dev_ctx,
const DenseTensor &out,
const DenseTensor &out_grad,
int axis,
DenseTensor *x_grad) {
dev_ctx.template Alloc<T>(x_grad);
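  // The backward pass is delegated to the shared GPU softmax driver; the
  // `true` template argument appears to select the log-softmax variant.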
phi::SoftmaxBackwardCUDAKernelDriver<T, true>(
dev_ctx, out, out_grad, axis, x_grad);
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(log_softmax_grad,
GPU,
ALL_LAYOUT,
phi::LogSoftmaxGradKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(log_softmax_grad,
GPU,
ALL_LAYOUT,
phi::LogSoftmaxGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/log_softmax_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace phi {
template <typename T, typename Context>
void LogSoftmaxKernel(const Context &dev_ctx,
const DenseTensor &x,
int axis,
DenseTensor *out) {
dev_ctx.template Alloc<T>(out);
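  // Forward log-softmax along `axis`, reusing the same GPU softmax driver as
  // the backward kernel above.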
phi::SoftmaxForwardCUDAKernelDriver<T, true>(dev_ctx, x, axis, out);
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(log_softmax,
GPU,
ALL_LAYOUT,
phi::LogSoftmaxKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(log_softmax,
GPU,
ALL_LAYOUT,
phi::LogSoftmaxKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
#define CUDA_NUM_THREADS 1024
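// Ceiling division: number of blocks of CUDA_NUM_THREADS threads needed to
// cover N elements.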
inline static int PADDLE_GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
template <typename T>
__global__ void PReluChannelFirstWiseKernel(const T *input,
const T *alpha,
T *output,
size_t channel_num,
size_t plane_size,
size_t numel) {
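  // NCHW layout: the channel index is (linear_index / plane_size) % channel_num.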
CUDA_KERNEL_LOOP(index, numel) {
size_t temp = index / plane_size;
size_t channel_index = temp % channel_num;
T scale = alpha[channel_index];
T x = input[index];
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
}
}
template <typename T>
__global__ void PReluChannelLastWiseKernel(const T *input,
const T *alpha,
T *output,
size_t channel_num,
size_t numel) {
CUDA_KERNEL_LOOP(index, numel) {
size_t channel_index = index % channel_num;
T scale = alpha[channel_index];
T x = input[index];
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
}
}
template <typename T>
__global__ void PReluElementWiseKernel(const T *input,
const T *alpha,
T *output,
size_t spatial_size,
size_t numel) {
CUDA_KERNEL_LOOP(index, numel) {
size_t element_index = index % spatial_size;
T scale = alpha[element_index];
T x = input[index];
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
}
}
template <typename T>
__global__ void PReluScalarKernel(const T *input,
const T *alpha,
T *output,
size_t numel) {
T scale = alpha[0];
CUDA_KERNEL_LOOP(index, numel) {
T x = input[index];
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
}
}
template <typename T>
class PreluChannelWiseDirectCUDAFunctor {
public:
void operator()(gpuStream_t stream,
const T *input,
const T *alpha,
T *output,
size_t batch_size,
size_t channel,
bool channel_last,
size_t numel);
};
template <typename T>
class PreluElementWiseDirectCUDAFunctor {
public:
void operator()(gpuStream_t stream,
const T *input,
const T *alpha,
T *output,
size_t batch_size,
size_t numel);
};
template <typename T>
class PreluScalarDirectCUDAFunctor {
public:
void operator()(gpuStream_t stream,
const T *input,
const T *alpha,
T *output,
size_t numel);
};
template <typename T>
void PreluChannelWiseDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
const T *input,
const T *alpha,
T *output,
size_t batch_size,
size_t channel,
bool channel_last,
size_t numel) {
if (channel_last) {
PReluChannelLastWiseKernel<<<PADDLE_GET_BLOCKS(numel),
CUDA_NUM_THREADS,
0,
stream>>>(
input, alpha, output, channel, numel);
} else {
PReluChannelFirstWiseKernel<<<PADDLE_GET_BLOCKS(numel),
CUDA_NUM_THREADS,
0,
stream>>>(
input, alpha, output, channel, numel / batch_size / channel, numel);
}
}
template <typename T>
void PreluElementWiseDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
const T *input,
const T *alpha,
T *output,
size_t batch_size,
size_t numel) {
PReluElementWiseKernel<<<PADDLE_GET_BLOCKS(numel),
CUDA_NUM_THREADS,
0,
stream>>>(
input, alpha, output, numel / batch_size, numel);
}
template <typename T>
void PreluScalarDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
const T *input,
const T *alpha,
T *output,
size_t numel) {
PReluScalarKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, numel);
}
template class PreluChannelWiseDirectCUDAFunctor<float>;
template class PreluChannelWiseDirectCUDAFunctor<phi::dtype::float16>;
template class PreluChannelWiseDirectCUDAFunctor<double>;
template class PreluElementWiseDirectCUDAFunctor<float>;
template class PreluElementWiseDirectCUDAFunctor<phi::dtype::float16>;
template class PreluElementWiseDirectCUDAFunctor<double>;
template class PreluScalarDirectCUDAFunctor<float>;
template class PreluScalarDirectCUDAFunctor<phi::dtype::float16>;
template class PreluScalarDirectCUDAFunctor<double>;
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/prelu_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpu/prelu_funcs.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
namespace phi {
enum PRELU_MODE { Element, ChannelFirst, ChannelLast, PRELU_Scalar };
template <typename T>
__global__ void PReluOpGradKernel(const T* x_ptr,
const T* alpha_ptr,
const T* out_grad_ptr,
T* x_grad_ptr,
T* alpha_grad_ptr,
size_t channel_num,
size_t plane_size,
size_t spatial_size,
size_t numel,
PRELU_MODE mode) {
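  // Per-element gradients: dx = dout when x > 0, otherwise scale * dout; the
  // (not yet reduced) alpha gradient is 0 when x > 0, otherwise x * dout.
  // scale is looked up according to the PRELU_MODE.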
CUDA_KERNEL_LOOP(index, numel) {
T scale;
if (mode == Element) {
size_t element_index = index % spatial_size;
scale = alpha_ptr[element_index];
} else if (mode == ChannelFirst) {
size_t temp = index / plane_size;
size_t channel_index = temp % channel_num;
scale = alpha_ptr[channel_index];
} else if (mode == ChannelLast) {
size_t channel_index = index % channel_num;
scale = alpha_ptr[channel_index];
} else {
scale = alpha_ptr[0];
}
T x = x_ptr[index];
T out_grad = out_grad_ptr[index];
T zero = static_cast<T>(0);
if (x_grad_ptr != nullptr)
x_grad_ptr[index] = (x > zero) ? out_grad : scale * out_grad;
if (alpha_grad_ptr != nullptr)
alpha_grad_ptr[index] = (x > zero) ? zero : x * out_grad;
}
}
template <typename T>
class PreluOpGradFunctor {
public:
void operator()(gpuStream_t stream,
const T* x,
const T* alpha,
const T* out_grad,
T* x_grad,
T* alpha_grad,
const DDim& input_dims,
PRELU_MODE mode) {
size_t numel = 1;
for (size_t i = 0; i < input_dims.size(); ++i) {
numel *= input_dims[i];
}
size_t plane_size = numel / input_dims[0] / input_dims[1];
size_t spatial_size = numel / input_dims[0];
size_t channel =
mode == ChannelLast ? input_dims[input_dims.size() - 1] : input_dims[1];
PReluOpGradKernel<
T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
x,
alpha,
out_grad,
x_grad,
alpha_grad,
channel,
plane_size,
spatial_size,
numel,
mode);
}
};
template <typename T, typename Context>
void PReluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& alpha,
const DenseTensor& out_grad,
const std::string& mode,
const std::string& data_format,
DenseTensor* x_grad,
DenseTensor* alpha_grad) {
dev_ctx.template Alloc<T>(x_grad);
const T* x_ptr = x.data<T>();
const T* alpha_ptr = alpha.data<T>();
const T* out_grad_ptr = out_grad.data<T>();
T* x_grad_ptr = x_grad ? dev_ctx.template Alloc<T>(x_grad) : nullptr;
T* alpha_grad_ptr =
alpha_grad ? dev_ctx.template Alloc<T>(alpha_grad) : nullptr;
if (!x_grad && !alpha_grad) return;
int numel = x.numel();
auto dim = x.dims();
auto x_rank = dim.size();
std::vector<int> input_shape = phi::vectorize<int>(dim);
auto stream = dev_ctx.stream();
T* alpha_grad_tmp_ptr;
DenseTensor alpha_grad_tmp;
if (alpha_grad_ptr == nullptr) {
alpha_grad_tmp_ptr = alpha_grad_ptr;
} else {
DenseTensorMeta alpha_grad_meta(
alpha_grad->dtype(), dim, alpha_grad->layout());
alpha_grad_tmp = phi::Empty(dev_ctx, std::move(alpha_grad_meta));
alpha_grad_tmp_ptr = alpha_grad_tmp.data<T>();
}
PRELU_MODE m;
bool channel_last = false;
if (mode == "element") {
m = Element;
} else if (mode == "channel") {
channel_last = data_format == "NHWC";
m = channel_last ? ChannelLast : ChannelFirst;
} else {
m = PRELU_Scalar;
}
PreluOpGradFunctor<T> prelu_grad;
prelu_grad(stream,
x_ptr,
alpha_ptr,
out_grad_ptr,
x_grad_ptr,
alpha_grad_tmp_ptr,
dim,
m);
if (alpha_grad_tmp_ptr == nullptr) return;
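  // Sum the element-wise alpha gradient over every axis alpha does not own:
  // keep the channel axis in "channel" mode, keep all non-batch axes in
  // "element" mode, and reduce every axis for the scalar mode.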
std::vector<int> reduce_dims;
for (size_t i = 0; i < dim.size(); i++) {
if (mode == "channel" && !channel_last && i == 1) continue;
if (mode == "channel" && channel_last && i == dim.size() - 1) continue;
if (mode == "element" && i != 0) continue;
reduce_dims.push_back(i);
}
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
static_cast<const phi::GPUContext&>(dev_ctx),
alpha_grad_tmp,
alpha_grad,
kps::IdentityFunctor<T>(),
reduce_dims);
}
} // namespace phi
PD_REGISTER_KERNEL(prelu_grad,
GPU,
ALL_LAYOUT,
phi::PReluGradKernel,
float,
phi::dtype::float16,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/prelu_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/prelu_funcs.h"
namespace phi {
template <typename T, typename Context>
void PReluKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& alpha,
const std::string& mode,
const std::string& data_format,
DenseTensor* out) {
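  // Dispatch on mode: "channel" applies a per-channel alpha (NCHW or NHWC),
  // "element" a per-element alpha over the non-batch dimensions, and the
  // default branch a single scalar alpha.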
const T* x_ptr = x.data<T>();
T* o_ptr = dev_ctx.template Alloc<T>(out);
const T* alpha_ptr = alpha.data<T>();
int numel = x.numel();
auto dim = x.dims();
auto x_rank = dim.size();
VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1] << ", dim["
<< x_rank - 1 << "]:" << dim[x_rank - 1] << ", numel:" << numel;
if (mode == "channel") {
bool channel_last = data_format == "NHWC";
size_t channel = channel_last ? dim[x_rank - 1] : dim[1];
PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise;
prelu_channel_wise(dev_ctx.stream(),
x_ptr,
alpha_ptr,
o_ptr,
dim[0],
channel,
channel_last,
numel);
} else if (mode == "element") {
PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise;
prelu_element_wise(
dev_ctx.stream(), x_ptr, alpha_ptr, o_ptr, dim[0], numel);
} else {
PreluScalarDirectCUDAFunctor<T> prelu_scalar;
prelu_scalar(dev_ctx.stream(), x_ptr, alpha_ptr, o_ptr, numel);
}
}
} // namespace phi
PD_REGISTER_KERNEL(prelu,
GPU,
ALL_LAYOUT,
phi::PReluKernel,
float,
phi::dtype::float16,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LogSoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LogSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void PReluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& alpha,
const DenseTensor& out_grad,
const std::string& mode,
const std::string& data_format,
DenseTensor* x_grad,
DenseTensor* alpha_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void PReluKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& alpha,
const std::string& mode,
const std::string& data_format,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
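// Maps the fluid "gelu"/"gelu_grad" ops (inputs, attributes, outputs) onto the
// phi kernel signatures so the existing operator definitions can dispatch to
// the migrated kernels.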
KernelSignature GeluOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("gelu", {"X"}, {"approximate"}, {"Out"});
}
KernelSignature GeluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("gelu_grad",
{"X", GradVarName("Out")},
{"approximate"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(gelu_grad, phi::GeluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(gelu, phi::GeluOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LogSoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("log_softmax_grad",
{"Out", GradVarName("Out")},
{"axis"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(log_softmax_grad,
phi::LogSoftmaxGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature PReluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("prelu_grad",
{"X", "Alpha", GradVarName("Out")},
{"mode", "data_format"},
{GradVarName("X"), GradVarName("Alpha")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(prelu_grad, phi::PReluGradOpArgumentMapping);