Unverified commit 6eac06e3, authored by niuliling123, committed by GitHub

Add OpFunctor and replace cast, scale, clip, bce_loss and abs_grad with elementwise_no_broadcast (#38500)
Parent: 1345a456
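For orientation, the sketch below illustrates the pattern every file in this diff adopts: the per-element computation is written as a device functor and handed to the shared LaunchSameDimsElementwiseCudaKernel helper instead of being launched as a one-off __global__ kernel. This is only an illustration, not part of the commit; AddBiasFunctor and RunAddBias are hypothetical names, and the call shape simply mirrors the calls visible in the hunks below.

// Illustrative sketch only (not part of the commit). Assumes the Paddle
// headers used in this diff; AddBiasFunctor and RunAddBias are made-up names.
#include <vector>

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"

namespace paddle {
namespace operators {

// Step 1: describe the per-element computation as a device functor.
template <typename T>
struct AddBiasFunctor {
  T bias;
  explicit AddBiasFunctor(T bias_data) : bias(bias_data) {}
  __device__ __forceinline__ T operator()(const T& x) const {
    return x + bias;
  }
};

// Step 2: hand inputs, outputs and the functor to the shared element-wise
// launcher instead of configuring and launching a raw __global__ kernel.
template <typename T>
void RunAddBias(const platform::CUDADeviceContext& dev_ctx,
                const framework::Tensor* x, framework::Tensor* out, T bias) {
  std::vector<const framework::Tensor*> ins = {x};
  std::vector<framework::Tensor*> outs = {out};
  // As in the kernels below, the output must already have memory
  // allocated (mutable_data) before the launch.
  LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
      dev_ctx, ins, &outs, AddBiasFunctor<T>(bias));
}

}  // namespace operators
}  // namespace paddle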
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include <algorithm>
 #include "paddle/fluid/operators/bce_loss_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
@@ -23,6 +24,17 @@ namespace operators {
 using Tensor = framework::Tensor;

+template <typename T>
+struct BCELossGradFunctor {
+  T one = static_cast<T>(1.0f);
+  T eps = static_cast<T>(1e-12);
+  __device__ __forceinline__ T operator()(const T& x, const T& label,
+                                          const T& dout) const {
+    T term1 = max((one - x) * x, eps);
+    return (dout * (x - label) / term1);
+  }
+};
+
 template <typename T>
 __global__ void GPUBCELossForward(const T* x_data, const T* label_data,
                                   T* out_data, const int in_numel) {
@@ -44,23 +56,6 @@ __global__ void GPUBCELossForward(const T* x_data, const T* label_data,
   }
 }

-template <typename T>
-__global__ void GPUBCELossBackward(const T* x_data, const T* label_data,
-                                   const T* dout_data, T* dx_data,
-                                   const int in_numel) {
-  CUDA_KERNEL_LOOP(i, in_numel) {
-    T x = x_data[i];
-    T label = label_data[i];
-    T dout = dout_data[i];
-    T one = static_cast<T>(1.);
-    T eps = static_cast<T>(1e-12);
-    T term1 = max((one - x) * x, eps);
-    dx_data[i] = dout * (x - label) / term1;
-  }
-}
-
 template <typename DeviceContext, typename T>
 class BCELossCUDAKernel : public framework::OpKernel<T> {
  public:
@@ -91,17 +86,13 @@ class BCELossGradCUDAKernel : public framework::OpKernel<T> {
     auto* labels = ctx.Input<Tensor>("Label");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    int x_numel = x->numel();
-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    auto& dev_ctx = ctx.cuda_device_context();
-    platform::GpuLaunchConfig config =
-        platform::GetGpuLaunchConfig1D(dev_ctx, x_numel);
-    GPUBCELossBackward<T><<<config.block_per_grid, config.thread_per_block, 0,
-                            dev_ctx.stream()>>>(
-        x->data<T>(), labels->data<T>(), dout->data<T>(), dx_data, x_numel);
+    dx->mutable_data<T>(ctx.GetPlace());
+    std::vector<const framework::Tensor*> ins = {x, labels, dout};
+    std::vector<framework::Tensor*> outs = {dx};
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto functor = BCELossGradFunctor<T>();
+    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
+        dev_ctx, ins, &outs, functor);
   }
 };
...
@@ -18,6 +18,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/transform.h"
+#if defined(__NVCC__) || defined(__HIPCC__)
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
+#endif

 namespace paddle {
 namespace operators {
@@ -25,17 +28,6 @@ namespace operators {
 using framework::Tensor;
 using platform::Transform;

-#if defined(__NVCC__) || defined(__HIPCC__)
-template <typename T, typename UnaryOperation>
-__global__ void ClipCudaKernel(const T* input, T* out, int num,
-                               UnaryOperation op) {
-  int idx = threadIdx.x + blockDim.x * blockIdx.x;
-  if (idx < num) {
-    out[idx] = op(input[idx]);
-  }
-}
-#endif
-
 template <typename T>
 class ClipFunctor {
  public:
@@ -106,12 +98,12 @@ class ClipKernel : public framework::OpKernel<T> {
     int64_t numel = x->numel();
     if (platform::is_gpu_place(context.GetPlace())) {
 #if defined(__NVCC__) || defined(__HIPCC__)
-      int threads = 256;
-      int blocks = (numel + threads - 1) / threads;
-      ClipCudaKernel<T, ClipFunctor<T>><<<
-          blocks, threads, 0,
-          context.template device_context<platform::CUDADeviceContext>()
-              .stream()>>>(x_data, out_data, numel, ClipFunctor<T>(min, max));
+      std::vector<const framework::Tensor*> ins = {x};
+      std::vector<framework::Tensor*> outs = {out};
+      auto functor = ClipFunctor<T>(min, max);
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+          context.template device_context<platform::CUDADeviceContext>(), ins,
+          &outs, functor);
 #endif
     } else {
       Transform<DeviceContext> trans;
...
@@ -13,19 +13,39 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/operators/label_smooth_op.h"
 namespace paddle {
 namespace operators {

 template <typename T>
-__global__ void LabelSmoothRunOriginKernel(const int N, const float epsilon,
-                                           const int label_dim, const T* src,
-                                           T* dst) {
-  CUDA_KERNEL_LOOP(idx, N) {
-    dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
-               static_cast<T>(epsilon / label_dim);
-  }
-}
+struct LabelSmoothFunctor {
+  T epsilon;
+  T label_dim;
+
+  __forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data) {
+    epsilon = static_cast<T>(epsilon_data);
+    label_dim = static_cast<T>(label_dim_data);
+  }
+
+  __device__ __forceinline__ T operator()(const T& x) const {
+    return (static_cast<T>(1 - epsilon) * x +
+            static_cast<T>(epsilon / label_dim));
+  }
+};
+
+template <typename T>
+struct LabelSmoothGradFunctor {
+  T epsilon;
+
+  __forceinline__ LabelSmoothGradFunctor(float epsilon_data) {
+    epsilon = static_cast<T>(epsilon_data);
+  }
+
+  __device__ __forceinline__ T operator()(const T& x) const {
+    return static_cast<T>(1 - epsilon) * x;
+  }
+};

 template <typename T>
 __global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
@@ -38,14 +58,6 @@ __global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
   }
 }

-template <typename T>
-__global__ void LabelSmoothGradRunKernel(const int N, const float epsilon,
-                                         const T* src, T* dst) {
-  CUDA_KERNEL_LOOP(idx, N) {
-    dst[idx] = static_cast<T>(1 - epsilon) * src[idx];
-  }
-}
-
 template <typename DeviceContext, typename T>
 class LabelSmoothGPUKernel : public framework::OpKernel<T> {
  public:
@@ -69,8 +81,14 @@ class LabelSmoothGPUKernel : public framework::OpKernel<T> {
           size_prob, epsilon, dist_numel, in_data, dist_data, out_data);
     } else {
-      LabelSmoothRunOriginKernel<T><<<grid, threads, 0, stream>>>(
-          size_prob, epsilon, label_dim, in_data, out_data);
+      auto& dev_ctx =
+          ctx.template device_context<platform::CUDADeviceContext>();
+      std::vector<const framework::Tensor*> ins = {in_t};
+      std::vector<framework::Tensor*> outs = {out_t};
+      auto functor = LabelSmoothFunctor<T>(epsilon, label_dim);
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+          dev_ctx, ins, &outs, functor);
     }
   }
 };
@@ -84,15 +102,13 @@ class LabelSmoothGradGPUKernel : public framework::OpKernel<T> {
     d_in_t->mutable_data<T>(ctx.GetPlace());
     auto epsilon = ctx.Attr<float>("epsilon");
-    auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
-    const T* in_data = d_out_t->data<T>();
-    auto size_prob = d_out_t->numel();
-    T* out_data = d_in_t->mutable_data<T>(ctx.GetPlace());
-    int threads = 512;
-    int grid = (size_prob + threads - 1) / threads;
-    auto stream = ctx.cuda_device_context().stream();
-    LabelSmoothGradRunKernel<T><<<grid, threads, 0, stream>>>(
-        size_prob, epsilon, in_data, out_data);
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    std::vector<const framework::Tensor*> ins = {d_out_t};
+    std::vector<framework::Tensor*> outs = {d_in_t};
+    auto functor = LabelSmoothGradFunctor<T>(epsilon);
+    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+        dev_ctx, ins, &outs, functor);
   }
 };
 }  // namespace operators
...
@@ -19,6 +19,7 @@
 #include "paddle/pten/core/kernel_registry.h"
 // See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/aligned_vector.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/device/gpu/gpu_helper.h"
@@ -27,62 +28,24 @@
 namespace pten {

-template <typename InT, typename OutT, int VecSize>
-__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
-  using LoadT = paddle::platform::AlignedVector<InT, VecSize>;
-  using StoreT = paddle::platform::AlignedVector<OutT, VecSize>;
-  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
-  for (int64_t i = idx * VecSize; i < N;
-       i += blockDim.x * gridDim.x * VecSize) {
-    LoadT in_val;
-    paddle::platform::Load<InT, VecSize>(&in[i], &in_val);
-    StoreT out_val;
-#pragma unroll
-    for (int j = 0; j < VecSize; j++) {
-      out_val[j] = static_cast<OutT>(in_val[j]);
-    }
-    paddle::platform::Store<OutT, VecSize>(out_val, &out[i]);
-  }
-}
-
 template <typename InT, typename OutT>
-__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
-  CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
-}
-
-template <typename InT, typename OutT>
-void CastCUDAKernelImplWithPtr(const GPUContext& dev_ctx,
-                               const InT* in_data,
-                               OutT* out_data,
-                               int64_t size) {
-  paddle::platform::GpuLaunchConfig config =
-      paddle::platform::GetGpuLaunchConfig1D(dev_ctx, size);
-  int vec_size = paddle::platform::GetVectorizedSize<OutT>(out_data);
-  if (!std::is_same<InT, OutT>::value && vec_size == 4 && size % 4 == 0) {
-    VecCastCUDAKernel<InT, OutT, 4><<<config.block_per_grid,
-                                      config.thread_per_block,
-                                      0,
-                                      dev_ctx.stream()>>>(
-        in_data, size, out_data);
-  } else {
-    CastCUDAKernel<InT, OutT><<<config.block_per_grid,
-                                config.thread_per_block,
-                                0,
-                                dev_ctx.stream()>>>(in_data, size, out_data);
-  }
-}
+struct CastFuctor {
+  __device__ __forceinline__ OutT operator()(const InT& x) const {
+    return static_cast<OutT>(x);
+  }
+};

 template <typename InT, typename OutT>
 void CastCUDAKernelImpl(const GPUContext& dev_ctx,
                         const DenseTensor& x,
                         DenseTensor* out) {
-  auto* in_data = x.data<InT>();
-  auto size = x.numel();
-  auto* out_data = out->mutable_data<OutT>();
-  CastCUDAKernelImplWithPtr(dev_ctx, in_data, out_data, size);
+  std::vector<const DenseTensor*> inputs;
+  std::vector<DenseTensor*> outputs;
+  inputs.emplace_back(&x);
+  outputs.emplace_back(out);
+  out->mutable_data<OutT>();
+  LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, InT, OutT>(
+      dev_ctx, inputs, &outputs, CastFuctor<InT, OutT>());
 }

 template <typename T, typename Context>
...
@@ -16,11 +16,54 @@ limitations under the License. */
 #include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/core/kernel_registry.h"
-#include "paddle/pten/kernels/impl/scale_kernel_impl.h"
 // See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/float16.h"

+namespace pten {
+
+template <typename InT>
+struct ScaleFunctor {
+  InT bias;
+  InT scale;
+  bool bias_after_scale;
+
+  ScaleFunctor(InT scale_data, InT bias_data, bool is_bias_after_sacle) {
+    scale = scale_data;
+    bias = bias_data;
+    bias_after_scale = is_bias_after_sacle;
+  }
+
+  __device__ __forceinline__ InT operator()(const InT& x) const {
+    if (bias_after_scale) {
+      return scale * x + bias;
+    } else {
+      return scale * (x + bias);
+    }
+  }
+};
+
+template <typename T, typename ContextT>
+void Scale(const ContextT& dev_ctx,
+           const DenseTensor& x,
+           const Scalar& scale,
+           float bias,
+           bool bias_after_scale,
+           DenseTensor* out) {
+  std::vector<const DenseTensor*> inputs;
+  std::vector<DenseTensor*> outputs;
+  inputs.emplace_back(&x);
+  outputs.emplace_back(out);
+  out->mutable_data<T>();
+  LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+      dev_ctx,
+      inputs,
+      &outputs,
+      ScaleFunctor<T>(scale.to<T>(), static_cast<T>(bias), bias_after_scale));
+}
+
+}  // namespace pten
+
 PT_REGISTER_CTX_KERNEL(scale,
                        GPU,
                        ALL_LAYOUT,
...