Unverified commit 6ba0507d, authored by S ShenLiang, committed by GitHub

add fused dropout add (#51752)

Parent a10718e8
@@ -605,6 +605,16 @@
kernel :
func : frame_grad
- backward_op : fused_dropout_add_grad
forward : fused_dropout_add (Tensor x, Tensor y, Scalar p, bool is_test, str mode, int seed, bool fix_seed) -> Tensor(out), Tensor(seed_offset)
args : (Tensor seed_offset, Tensor out_grad, Scalar p, bool is_test, str mode, bool fix_seed)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [out_grad, out_grad]
kernel :
func : fused_dropout_add_grad
- backward_op : gather_nd_grad
forward : gather_nd (Tensor x, Tensor index) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad)
......
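The backward spec above is notable for what it does not take: no dropout mask tensor is saved. Only the two-element `seed_offset` output of the forward pass is threaded through, and the mask is regenerated from it inside the GPU kernel. As a reading aid, the gradient rule can be sketched in NumPy, with `mask` standing in for the regenerated random sample (a sketch of the semantics, not the kernel's code path):

```python
import numpy as np

def fused_dropout_add_grad_ref(out_grad, mask, p, mode="upscale_in_train"):
    # y enters the fused op through a plain add, so its gradient is
    # out_grad unchanged; x sees the (optionally rescaled) mask.
    if mode == "upscale_in_train":
        x_grad = out_grad * mask / (1.0 - p)  # assumes p < 1 on this branch
    else:  # downscale_in_infer
        x_grad = out_grad * mask
    y_grad = out_grad
    return x_grad, y_grad
```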
@@ -604,6 +604,16 @@
func : frame
backward : frame_grad
- op : fused_dropout_add
args : (Tensor x, Tensor y, Scalar p, bool is_test, str mode, int seed, bool fix_seed)
output : Tensor(out), Tensor(seed_offset)
infer_meta :
func : FusedDropoutAddInferMeta
kernel :
func : fused_dropout_add
data_type : x
backward : fused_dropout_add_grad
- op : fused_linear_param_grad_add
args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true)
output : Tensor(dweight_out), Tensor(dbias_out)
......
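Functionally, the forward op fuses `dropout(x) + y` into a single kernel launch. A NumPy sketch of its semantics under both modes, useful as a reference when reading the CUDA code below (illustration only, not the implementation):

```python
import numpy as np

def fused_dropout_add_ref(x, y, p, training=True, mode="upscale_in_train",
                          rng=None):
    rng = rng or np.random.default_rng(0)
    if not training:  # is_test path: deterministic, no randomness
        scale = 1.0 if mode == "upscale_in_train" else 1.0 - p
        return x * scale + y
    mask = (rng.random(x.shape) >= p).astype(x.dtype)
    if mode == "upscale_in_train":  # assumes p < 1 on this branch
        return x * mask / (1.0 - p) + y
    return x * mask + y  # downscale_in_infer
```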
@@ -1287,6 +1287,22 @@ void FillDiagonalTensorInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void FusedDropoutAddInferMeta(const MetaTensor& x,
const MetaTensor& y,
const Scalar& p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
MetaTensor* out,
MetaTensor* seed_offset) {
out->share_meta(x);
if (seed_offset) {
seed_offset->set_dims({2});
seed_offset->set_dtype(DataType::INT64);
}
}
// Used in FusedMatmulInferMeta
static std::vector<int64_t> GetInputShape(phi::DDim dim,
std::vector<int> shape,
......
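The infer-meta contract of `FusedDropoutAddInferMeta` is small: `out` mirrors `x` exactly (shape, dtype, layout via `share_meta`), and `seed_offset` is a fixed two-element int64 tensor holding the Philox seed and offset. A quick eager-mode check of that contract, assuming a CUDA build of Paddle at this commit (the kernel is GPU-only):

```python
import paddle
from paddle import _C_ops

x = paddle.randn([2, 3], dtype='float16')
y = paddle.randn([2, 3], dtype='float16')
# args: x, y, p, is_test, mode, seed, fix_seed
out, seed_offset = _C_ops.fused_dropout_add(
    x, y, 0.5, False, 'upscale_in_train', 0, False)
assert out.shape == x.shape and out.dtype == x.dtype
assert seed_offset.shape == [2] and seed_offset.dtype == paddle.int64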
@@ -222,6 +222,16 @@ void FillDiagonalTensorInferMeta(const MetaTensor& x,
int dim2,
MetaTensor* out);
void FusedDropoutAddInferMeta(const MetaTensor& x,
const MetaTensor& y,
const Scalar& p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
MetaTensor* out,
MetaTensor* seed_offset);
void FusedMatmulInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& residual_data,
......
@@ -81,11 +81,12 @@ set(COMMON_KERNEL_DEPS
utf8proc
gather_scatter_functor)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group)
if(WITH_FLASHATTN)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_dynload_flashattn)
endif()
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group)
if(WITH_NCCL OR WITH_RCCL)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group_nccl
nccl_comm_context)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void FusedDropoutAddGradKernel(const Context& dev_ctx,
const DenseTensor& seed_offset,
const DenseTensor& out_grad,
const Scalar& p,
bool is_test,
const std::string& mode,
bool fix_seed,
DenseTensor* x_grad,
DenseTensor* y_grad);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
namespace phi {
template <typename T, typename Context>
void FusedDropoutAddKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const Scalar& p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
DenseTensor* out,
DenseTensor* seed_offset);
template <typename Context>
static inline std::vector<size_t> GetRandomCudaProp(int numel,
const Context& dev_ctx) {
constexpr int kVecSize = funcs::uniform_distribution<float>::kReturnsCount;
auto gpu_config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, kVecSize);
size_t grid_size = gpu_config.GetGridSize();
size_t block_size = gpu_config.GetBlockSize();
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
size_t max_grid_size =
prop.maxThreadsPerMultiProcessor * prop.multiProcessorCount / block_size;
grid_size = std::min(grid_size, max_grid_size);
auto offset =
((numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize;
size_t main_offset =
numel / (block_size * kVecSize) * (block_size * kVecSize);
return {grid_size, block_size, offset, main_offset};
}
} // namespace phi
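`GetRandomCudaProp` packs four launch parameters into a vector: a grid size clamped to what the device can keep resident at once, the block size, the per-thread Philox `offset` (elements each thread will consume, rounded up to the vector width), and `main_offset`, the largest prefix of `numel` covered by full vectorized tiles. The same arithmetic in plain Python, with hypothetical device numbers for illustration (`k_vec_size` stands in for `uniform_distribution<float>::kReturnsCount`):

```python
def get_random_cuda_prop(numel, grid_size, block_size,
                         max_threads_per_sm=2048, sm_count=80, k_vec_size=4):
    # Clamp the grid so every block can be resident simultaneously.
    max_grid = max_threads_per_sm * sm_count // block_size
    grid_size = min(grid_size, max_grid)
    # Elements per thread, rounded up to a multiple of the vector width;
    # this is how far each thread must advance its Philox state.
    offset = ((numel - 1) // (grid_size * block_size * k_vec_size) + 1) * k_vec_size
    # Largest multiple of (block_size * k_vec_size) not exceeding numel.
    main_offset = numel // (block_size * k_vec_size) * (block_size * k_vec_size)
    return grid_size, block_size, offset, main_offset
```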
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/fused_dropout_add_grad_kernel.h"
#include "paddle/phi/kernels/fusion/fused_dropout_add_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/funcs/dropout_impl.cu.h"
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaximumNumBlocks);
}
namespace phi {
template <typename T, typename MT>
__global__ void FuseScaleAddGrad(const T* grad,
T* x,
T* y,
const MT factor,
const int64_t limit,
bool upscale_in_train) {
CUDA_KERNEL_LOOP(i, limit) {
y[i] = grad[i];
x[i] = upscale_in_train ? grad[i]
: static_cast<T>(static_cast<MT>(grad[i]) * factor);
}
}
template <typename T>
__global__ void FuseScaleAddGradRateZero(const T* grad,
T* src,
T* res,
const int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
res[i] = grad[i];
src[i] = 0;
}
}
template <typename T1, typename T2 = T1, typename OutT = T1>
struct NoMaskBwFunctor {
const float retain_prob_;
using MT = typename phi::kps::details::MPTypeTrait<T1>::Type;
MT factor_;
HOSTDEVICE inline NoMaskBwFunctor(const float retain_prob)
: retain_prob_(retain_prob) {
factor_ = static_cast<MT>(1.0f / retain_prob_);
}
HOSTDEVICE inline NoMaskBwFunctor(const float retain_prob, const MT factor)
: retain_prob_(retain_prob), factor_(factor) {}
HOSTDEVICE inline void operator()(OutT* dst,
const T1* src_val,
const T2* rand,
int num) const {
static constexpr int kCount =
phi::funcs::uniform_distribution<T2>::kReturnsCount;
#pragma unroll
for (int i = 0; i < kCount; i++) {
dst[i + kCount] = src_val[i];
dst[i] = rand[i] < retain_prob_
? static_cast<T1>(static_cast<MT>(src_val[i]) * factor_)
: static_cast<T1>(0);
}
}
};
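`NoMaskBwFunctor` writes into a double-width register buffer: slots `[0, kCount)` become the x-gradient (mask applied, rescaled by `factor_` in upscale mode; the other constructor pins `factor_` to 1), while slots `[kCount, 2*kCount)` pass the incoming gradient through untouched for y. A per-lane sketch of that layout:

```python
def no_mask_bw(dst, src_val, rand, retain_prob, factor, k_count):
    # src_val holds out_grad lanes; dst is 2 * k_count wide.
    for i in range(k_count):
        dst[i + k_count] = src_val[i]  # y_grad: identity
        dst[i] = src_val[i] * factor if rand[i] < retain_prob else 0.0
```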
template <typename T, typename Functor>
__global__ void VectorizedDropoutBackward(const size_t n,
uint64_t seed,
T* src,
T* res,
const T* dst,
uint64_t increment,
size_t main_offset,
Functor functor) {
size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
static constexpr int kCount =
phi::funcs::uniform_distribution<float>::kReturnsCount;
size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
using SType = hiprandStatePhilox4_32_10_t;
#else
curandStatePhilox4_32_10_t state;
curand_init(seed, idx + THREAD_ID_X, increment, &state);
using SType = curandStatePhilox4_32_10_t;
#endif
float rands[kCount];
T src_res[kCount * 2];
T res_grad[kCount];
using Rand = phi::funcs::uniform_distribution<float>;
using Cast = kps::IdentityFunctor<T>;
int deal_size = BLOCK_NUM_X * kCount;
size_t fix = idx * kCount;
for (; fix < main_offset; fix += stride) {
kps::ReadData<T, kCount, 1, false>(&src_res[0], dst, deal_size);
kps::ElementwiseRandom<SType, float, kCount, Rand>(
&rands[0], Rand(), &state);
// x_grad
kps::OperatorTernary<T, float, T, Functor>(
&src_res[0], &src_res[0], &rands[0], functor, kCount);
kps::WriteData<T, kCount, 1, false>(src + fix, &src_res[0], deal_size);
// res
kps::ElementwiseUnary<T, T, kCount, 1, Cast>(
&res_grad[0], &src_res[kCount], Cast());
kps::WriteData<T, kCount, 1, false>(res + fix, &res_grad[0], deal_size);
if (fix > idx * kCount + 1) {
__syncthreads();
}
}
int remainder = n - fix;
if (remainder > 0) {
kps::ReadData<T, kCount, 1, true>(&src_res[0], dst + fix, remainder);
kps::ElementwiseRandom<SType, float, kCount, Rand>(
&rands[0], Rand(), &state);
// x_grad
kps::OperatorTernary<T, float, T, Functor>(
&src_res[0], &src_res[0], &rands[0], functor, kCount);
kps::WriteData<T, kCount, 1, true>(src + fix, &src_res[0], remainder);
// res
kps::ElementwiseUnary<T, T, kCount, 1, Cast>(
&res_grad[0], &src_res[kCount], Cast());
kps::WriteData<T, kCount, 1, true>(res + fix, &res_grad[0], remainder);
__syncthreads();
}
}
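The kernel above is a grid-stride loop over `kCount`-wide vectors: full tiles up to `main_offset` use unchecked vectorized reads and writes, and a single boundary-checked pass handles the tail. Schematically (hypothetical helper, names chosen for illustration):

```python
def grid_stride_plan(n, grid, block, k_count, main_offset, block_id):
    stride = grid * block * k_count
    fix = block_id * block * k_count  # idx * kCount in the kernel
    tiles = []
    while fix < main_offset:
        tiles.append(('full', fix))            # vectorized, no bounds checks
        fix += stride
    if n - fix > 0:
        tiles.append(('tail', fix, n - fix))   # predicated read/write
    return tiles
```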
template <typename T, typename Context>
void FusedDropoutAddGradKernel(const Context& dev_ctx,
const DenseTensor& seed_offset,
const DenseTensor& out_grad,
const Scalar& p,
bool is_test,
const std::string& mode,
bool fix_seed,
DenseTensor* x_grad,
DenseTensor* y_grad) {
int64_t numel = out_grad.numel();
auto stream = dev_ctx.stream();
float dropout_rate = p.to<float>();
bool upscale_in_train = (mode == "upscale_in_train");
const auto* seed_offset_data = seed_offset.data<int64_t>();
const uint64_t seed_data = static_cast<uint64_t>(seed_offset_data[0]);
const uint64_t increment = static_cast<uint64_t>(seed_offset_data[1]);
auto* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
auto* y_grad_data = dev_ctx.template Alloc<T>(y_grad);
const auto* out_grad_data = out_grad.data<T>();
using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
int blocks = NumBlocks(numel);
int threads = kNumCUDAThreads;
if (is_test) {
MT factor = static_cast<MT>(1.0f - dropout_rate);
FuseScaleAddGrad<T, MT><<<blocks, threads, 0, stream>>>(out_grad_data,
x_grad_data,
y_grad_data,
factor,
numel,
upscale_in_train);
} else {
if (upscale_in_train && dropout_rate == 1.0f) {
FuseScaleAddGradRateZero<T><<<blocks, threads, 0, stream>>>(
out_grad_data, x_grad_data, y_grad_data, numel);
return;
}
auto random_prop = GetRandomCudaProp(numel, dev_ctx);
size_t grid_size = random_prop[0];
size_t block_size = random_prop[1];
size_t offset = random_prop[2];
size_t main_offset = random_prop[3];
auto functor = upscale_in_train
? NoMaskBwFunctor<T, float>(1.0f - dropout_rate)
: NoMaskBwFunctor<T, float>(1.0f - dropout_rate, 1.0f);
#define PD_DROPOUT_KERNEL_NAME \
VectorizedDropoutBackward<T, NoMaskBwFunctor<T, float>>
PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!fix_seed,
PD_DROPOUT_KERNEL_NAME,
grid_size,
block_size,
0,
stream,
offset,
KERNEL_PARAMS.As<uint64_t>(1),
KERNEL_PARAMS.As<uint64_t>(5),
numel,
seed_data, // need save
x_grad_data,
y_grad_data,
out_grad_data, // grad
increment, // need save
main_offset,
functor);
#undef PD_DROPOUT_KERNEL_NAME
}
}
} // namespace phi
PD_REGISTER_KERNEL(fused_dropout_add_grad,
GPU,
ALL_LAYOUT,
phi::FusedDropoutAddGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/fused_dropout_add_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"
#include "paddle/phi/kernels/funcs/dropout_impl.cu.h"
namespace phi {
template <typename T1, typename T2 = T1, typename OutT = T1>
struct NoMaskFwFunctor {
const float retain_prob_;
const bool is_upscale_in_train_;
using MT = typename phi::kps::details::MPTypeTrait<T1>::Type;
MT factor;
HOSTDEVICE inline NoMaskFwFunctor(const float retain_prob,
const bool is_upscale_in_train)
: retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) {
factor = static_cast<MT>(1.0f / retain_prob_);
}
HOSTDEVICE inline void operator()(OutT* dst,
const T1* src_val,
const T2* rand,
int num) const {
static constexpr int kCount =
phi::funcs::uniform_distribution<T2>::kReturnsCount;
#pragma unroll
for (int i = 0; i < kCount; i++) {
if (rand[i] < retain_prob_) {
dst[i] = is_upscale_in_train_
? static_cast<T1>(static_cast<MT>(src_val[i]) * factor)
: static_cast<T1>(src_val[i]);
dst[i] += src_val[i + kCount];
} else {
dst[i] = src_val[i + kCount];
}
}
}
};
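`NoMaskFwFunctor` consumes the same double-width layout from the other direction: slots `[0, kCount)` of `src_val` hold x and slots `[kCount, 2*kCount)` hold the residual y; kept lanes are rescaled by `1/retain_prob` only in upscale mode before the residual add. Per lane:

```python
def no_mask_fw(dst, src_val, rand, retain_prob, factor, upscale, k_count):
    # src_val[:k_count] = x lanes, src_val[k_count:] = y (residual) lanes
    for i in range(k_count):
        if rand[i] < retain_prob:
            kept = src_val[i] * factor if upscale else src_val[i]
            dst[i] = kept + src_val[i + k_count]
        else:
            dst[i] = src_val[i + k_count]  # dropped lane: residual only
```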
template <typename T>
struct ScaleAddFunctor {
  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;

  explicit ScaleAddFunctor(const MT factor, bool upscale_in_train)
      : factor_(factor), upscale_in_train_(upscale_in_train) {}

  __device__ __forceinline__ T operator()(const T src, const T res) const {
    return upscale_in_train_
               ? src + res
               : static_cast<T>(static_cast<MT>(src) * factor_) + res;
  }

 private:
  MT factor_;
  bool upscale_in_train_;
};
template <typename T, typename Functor>
__global__ void VectorizedDropoutForward(const size_t n,
uint64_t seed,
const T* src,
const T* res,
T* dst,
uint64_t increment,
size_t main_offset,
Functor functor) {
size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
static constexpr int kCount =
phi::funcs::uniform_distribution<float>::kReturnsCount;
size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
#ifdef PADDLE_WITH_HIP
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
using SType = hiprandStatePhilox4_32_10_t;
#else
curandStatePhilox4_32_10_t state;
curand_init(seed, idx + THREAD_ID_X, increment, &state);
using SType = curandStatePhilox4_32_10_t;
#endif
T dst_res[kCount * 2];
float rands[kCount];
using Rand = phi::funcs::uniform_distribution<float>;
int deal_size = BLOCK_NUM_X * kCount;
size_t fix = idx * kCount;
for (; fix < main_offset; fix += stride) {
kps::ReadData<T, kCount, 1, false>(&dst_res[0], src + fix, deal_size);
kps::ReadData<T, kCount, 1, false>(&dst_res[kCount], res + fix, deal_size);
kps::ElementwiseRandom<SType, float, kCount, Rand>(
&rands[0], Rand(), &state);
// dst
kps::OperatorTernary<T, float, T, Functor>(
&dst_res[0], &dst_res[0], &rands[0], functor, kCount);
kps::WriteData<T, kCount, 1, false>(dst + fix, &dst_res[0], deal_size);
if (fix > idx * kCount + 1) {
__syncthreads();
}
}
int remainder = n - fix;
if (remainder > 0) {
kps::ReadData<T, kCount, 1, true>(&dst_res[0], src + fix, remainder);
kps::ReadData<T, kCount, 1, true>(&dst_res[kCount], res + fix, remainder);
kps::ElementwiseRandom<SType, float, kCount, Rand>(
&rands[0], Rand(), &state);
// dst
kps::OperatorTernary<T, float, T, Functor>(
&dst_res[0], &dst_res[0], &rands[0], functor, kCount);
kps::WriteData<T, kCount, 1, true>(dst + fix, &dst_res[0], remainder);
__syncthreads();
}
}
template <typename T, typename Context>
void FusedDropoutAddKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const Scalar& p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
DenseTensor* out,
DenseTensor* seed_offset) {
auto* out_data = dev_ctx.template Alloc<T>(out);
auto* seed_offset_data = dev_ctx.template HostAlloc<int64_t>(seed_offset);
int64_t numel = x.numel();
auto stream = dev_ctx.stream();
bool upscale_in_train = (mode == "upscale_in_train");
const auto* x_data = x.data<T>();
const auto* y_data = y.data<T>();
float dropout_rate = p.to<float>();
if (!is_test) {
if (dropout_rate == 1.0f) {
phi::Copy(dev_ctx, y, dev_ctx.GetPlace(), false, out);
return;
}
uint64_t seed_data;
uint64_t increment;
auto random_prop = GetRandomCudaProp(numel, dev_ctx);
size_t grid_size = random_prop[0];
size_t block_size = random_prop[1];
size_t offset = random_prop[2];
size_t main_offset = random_prop[3];
funcs::GetSeedDataAndIncrement(
dev_ctx, nullptr, fix_seed, seed, offset, &seed_data, &increment);
seed_offset_data[0] = static_cast<int64_t>(seed_data);
seed_offset_data[1] = static_cast<int64_t>(increment);
auto dst_functor =
NoMaskFwFunctor<T, float>(1.0f - dropout_rate, upscale_in_train);
#define PD_DROPOUT_KERNEL_NAME \
VectorizedDropoutForward<T, NoMaskFwFunctor<T, float>>
PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!fix_seed,
PD_DROPOUT_KERNEL_NAME,
grid_size,
block_size,
0,
stream,
offset,
KERNEL_PARAMS.As<uint64_t>(1),
KERNEL_PARAMS.As<uint64_t>(5),
numel,
seed_data, // need save
x_data,
y_data,
out_data,
increment, // need save
main_offset,
dst_functor);
#undef PD_DROPOUT_KERNEL_NAME
} else {
using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
MT factor = static_cast<MT>(1.0f - dropout_rate);
std::vector<phi::DenseTensor*> outs = {out};
std::vector<const phi::DenseTensor*> ins = {&x, &y};
phi::funcs::ElementwiseKernel<T>(
    dev_ctx, ins, &outs, ScaleAddFunctor<T>(factor, upscale_in_train));
}
}
} // namespace phi
PD_REGISTER_KERNEL(fused_dropout_add,
GPU,
ALL_LAYOUT,
phi::FusedDropoutAddKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
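With both kernels registered, the fused op can be sanity-checked against the unfused composition; the `training=False` path is deterministic, so the comparison is exact (requires a CUDA build, since the op is GPU-only):

```python
import paddle
from paddle.incubate.nn.functional import fused_dropout_add

x = paddle.randn([4, 10], dtype='float32')
y = paddle.randn([4, 10], dtype='float32')
out = fused_dropout_add(x, y, p=0.5, training=False)
ref = paddle.nn.functional.dropout(x, p=0.5, training=False) + y
assert paddle.allclose(out, ref)
```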
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle import fluid
from paddle.incubate.nn.functional import fused_dropout_add
from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
def paddle_dropout_add(x, y, p=0.5, training=True, mode="upscale_in_train"):
tmp = paddle.nn.functional.dropout(x, p, training=training, mode=mode)
return tmp + y
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not compiled with CUDA ",
)
class TestFusedDropoutAdd(unittest.TestCase):
def setUp(self):
self.shape = (2, 10, 10, 2)
self.dtype = 'float64'
self.dropout_rate = 0.9
self.training = True
self.mode = "upscale_in_train"
self.seed = 1027
def get_paddle_tensor(self):
tmp = paddle.randn(self.shape, self.dtype)
tmp.stop_gradient = False
return tmp
def get_forward_backward(self, dropout_add, seed):
paddle.disable_static()
paddle.seed(seed)
count = 3
data = []
fw = []
bw = []
for _ in range(count):
data.append(self.get_paddle_tensor())
out = data[0]
for i in range(1, count):
out = dropout_add(
out,
data[i],
p=self.dropout_rate,
training=self.training,
mode=self.mode,
)
fw.append(out)
loss = paddle.mean(out)
loss.backward()
for i in range(count):
bw.append(data[i].grad)
return fw, bw
def test_fused_dropout_add(self):
p_fw, p_bw = self.get_forward_backward(
paddle_dropout_add, seed=self.seed
)
f_fw, f_bw = self.get_forward_backward(
fused_dropout_add, seed=self.seed
)
for i in range(len(p_fw)):
np.testing.assert_allclose(
p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05
)
np.testing.assert_allclose(
p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05
)
def create_test_class(parent, dtype, mode, training, p, seed):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestFusedDropoutAddCase(parent):
def setUp(self):
self.shape = (2, 10, 10, 2)
self.dtype = dtype
self.dropout_rate = p
self.training = training
self.mode = mode
self.seed = seed
cls_name = "{0}_{1}_{2}_{3}_{4}_{5}".format(
parent.__name__, dtype, mode, str(training), str(p), str(seed)
)
TestFusedDropoutAddCase.__name__ = cls_name
globals()[cls_name] = TestFusedDropoutAddCase
for dtype in ["float64", "float32", "float16"]:
for mode in ["upscale_in_train", "downscale_in_infer"]:
for p in [0.0, 0.5, 0.9, 1.0]:
for training in [True, False]:
for seed in [0, 1024]:
create_test_class(
TestFusedDropoutAdd, dtype, mode, training, p, seed
)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestFusedDropoutAddStatic(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 80, 8, 2)
self.dtype = 'float16'
def test_static_op(self):
paddle.disable_static()
paddle.seed(312)
x_data = np.random.random(self.shape)
y_data = np.random.random(self.shape)
x = paddle.to_tensor(
x_data, place=self.place, dtype=self.dtype, stop_gradient=False
)
y = paddle.to_tensor(
y_data, place=self.place, dtype=self.dtype, stop_gradient=False
)
out = fused_dropout_add(x, y, p=0.5, training=True)
paddle.enable_static()
paddle.seed(312)
with paddle.static.program_guard(paddle.static.Program()):
xs = paddle.static.data(
name="xs", shape=self.shape, dtype=self.dtype
)
ys = paddle.static.data(
name="ys", shape=self.shape, dtype=self.dtype
)
outs = fused_dropout_add(xs, ys, p=0.5, training=True)
exe = fluid.Executor(self.place)
out_s = exe.run(
feed={
"xs": x_data.astype('float16'),
"ys": y_data.astype('float16'),
},
fetch_list=[outs],
)
np.testing.assert_allclose(out_s[0], out)
def test_fused_dropout_add_layer(self):
x = paddle.randn(self.shape, self.dtype)
y = paddle.randn(self.shape, self.dtype)
fused_d_a = FusedDropoutAdd(p=0.5)
d = paddle.nn.Dropout(p=0.5)
print(d)
paddle.seed(2048)
fused_out = fused_d_a(x, y)
paddle.seed(2048)
out = d(x) + y
np.testing.assert_allclose(fused_out, out)
def test_assert(self):
def check_raise():
x = paddle.randn(self.shape, self.dtype)
y = paddle.randn(self.shape, self.dtype)
fused_d_a = FusedDropoutAdd(p=-1)
fused_out = fused_d_a(x, y)
self.assertRaises(ValueError, check_raise)
if __name__ == '__main__':
unittest.main()
@@ -21,6 +21,7 @@ from .layer.fused_transformer import (
FusedBiasDropoutResidualLayerNorm,
) # noqa: F401
from .layer.fused_ec_moe import FusedEcMoe # noqa: F401
from .layer.fused_dropout_add import FusedDropoutAdd # noqa: F401
__all__ = [ # noqa
'FusedMultiHeadAttention',
@@ -30,4 +31,5 @@ __all__ = [  # noqa
'FusedLinear',
'FusedBiasDropoutResidualLayerNorm',
'FusedEcMoe',
'FusedDropoutAdd',
]
@@ -18,6 +18,7 @@ from .fused_transformer import fused_multi_transformer
from .fused_matmul_bias import fused_matmul_bias, fused_linear
from .fused_transformer import fused_bias_dropout_residual_layer_norm
from .fused_ec_moe import fused_ec_moe
from .fused_dropout_add import fused_dropout_add
__all__ = [
'fused_multi_head_attention',
@@ -27,4 +28,5 @@ __all__ = [
'fused_linear',
'fused_bias_dropout_residual_layer_norm',
'fused_ec_moe',
'fused_dropout_add',
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _C_ops
from paddle.common_ops_import import default_main_program
from paddle.fluid import core
from paddle.fluid.framework import in_dygraph_mode
from paddle.framework import LayerHelper
def fused_dropout_add(
x, y, p=0.5, training=True, mode='upscale_in_train', name=None
):
r"""
Fused Dropout and Add.
Args:
x (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
y (Tensor): The input tensor. The data type is bfloat16, float16, float32 or float64.
p (float|int, optional): Probability of setting units to zero. Default: 0.5.
training (bool, optional): A flag indicating whether it is in the training phase or not. Default: True.
mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer'].
1. upscale_in_train (default), upscale the output at training time
- train: :math:`out = x \times \frac{mask}{(1.0 - dropout\_prob)} + y`
- inference: :math:`out = x + y`
2. downscale_in_infer, downscale the output at inference
- train: :math:`out = x \times mask + y`
- inference: :math:`out = x \times (1.0 - dropout\_prob) + y`
name (str, optional): Name for the operation. Default: None. For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor representing the fused dropout and add, with the same shape and data type as `x`.
Examples:
.. code-block:: python
# required: gpu
import paddle
from paddle.incubate.nn.functional import fused_dropout_add
x = paddle.randn([4, 10], dtype='float16')
y = paddle.randn([4, 10], dtype='float16')
out = fused_dropout_add(x, y, p=0.5)
"""
if isinstance(p, (int, float)):
# fast return for p == 0
if p == 0:
return x + y
elif p < 0 or p > 1:
raise ValueError("p argument should between 0 and 1")
if mode not in ('downscale_in_infer', 'upscale_in_train'):
raise ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
)
seed = None
if in_dygraph_mode():
if default_main_program().random_seed != 0:
seed = default_main_program().random_seed
out, seed_offset = _C_ops.fused_dropout_add(
x,
y,
p,
not training,
mode,
seed if seed is not None else 0,
seed is not None,
)
return out
else:
helper = LayerHelper('fused_dropout_add', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
seed_offset = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.INT64, stop_gradient=True
)
def get_attrs(prog, dropout_prob, is_test, seed):
if (seed is None or seed == 0) and prog.random_seed != 0:
seed = prog.random_seed
attrs = {
'p': dropout_prob,
'is_test': is_test,
'mode': mode,
'seed': seed if seed is not None else 0,
'fix_seed': seed is not None,
}
return attrs
attrs = get_attrs(helper.main_program, p, not training, seed)
helper.append_op(
type='fused_dropout_add',
inputs={'x': x, 'y': y},
outputs={'out': [out], 'seed_offset': [seed_offset]},
attrs=attrs,
)
return out
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.incubate.nn import functional as F
from paddle.nn import Layer
class FusedDropoutAdd(Layer):
r"""
Fused Dropout and Add.
Parameters:
p (float|int, optional): Probability of setting units to zero. Default: 0.5
mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
1. upscale_in_train (default), upscale the output at training time
- train: :math:`out = x \times \frac{mask}{(1.0 - p)} + y`
- inference: :math:`out = x + y`
2. downscale_in_infer, downscale the output at inference
- train: :math:`out = x \times mask + y`
- inference: :math:`out = x \times (1.0 - p) + y`
name (str, optional): Name for the operation. Default: None. For more information, please refer to :ref:`api_guide_Name`.
Shape:
- x: N-D tensor.
- y: N-D tensor.
- output: N-D tensor, the same shape as x.
Examples:
.. code-block:: python
# required: gpu
import paddle
from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
x = paddle.to_tensor([[1,2,3], [4,5,6]], dtype="float32")
y = paddle.to_tensor([[1,2,3], [4,5,6]], dtype="float32")
m = FusedDropoutAdd(p=0.5)
out = m(x, y)
"""
def __init__(self, p=0.5, mode="upscale_in_train", name=None):
super().__init__()
self.p = p
self.mode = mode
self.name = name
def forward(self, x, y):
out = F.fused_dropout_add(
x,
y,
p=self.p,
training=self.training,
mode=self.mode,
name=self.name,
)
return out
def extra_repr(self):
name_str = ', name={}'.format(self.name) if self.name else ''
return 'p={}, mode={}{}'.format(self.p, self.mode, name_str)