未验证 提交 99fc1b08 编写于 作者: H hong 提交者: GitHub

Move dropout to phi (#40148)

* move dropout to phi; test=develop

* fix xpu, npu compile error; test=develop
上级 843f6da0
......@@ -89,5 +89,5 @@ class DropoutOpConverter : public OpConverter {
} // namespace inference
} // namespace paddle
USE_OP(dropout);
USE_OP_ITSELF(dropout);
REGISTER_TRT_OP_CONVERTER(dropout, DropoutOpConverter);
......@@ -57,4 +57,4 @@ TEST(DropoutOpConverter, main) {
} // namespace inference
} // namespace paddle
USE_OP(dropout);
USE_OP_ITSELF(dropout);
......@@ -23,7 +23,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -27,7 +27,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -25,7 +25,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -32,10 +32,9 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/functors.h"
namespace paddle {
......@@ -177,12 +176,13 @@ __global__ void DropoutGradCUDAKernel(
}
template <typename T>
void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
bool is_test,
void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test,
const std::string dropout_implementation,
float dropout_prob, bool upscale_in_train,
bool is_fix_seed, int seed_val, const Tensor& x,
const Tensor* seed, Tensor* mask, Tensor* y) {
bool is_fix_seed, int seed_val,
const framework::Tensor& x,
const framework::Tensor* seed,
framework::Tensor* mask, framework::Tensor* y) {
auto& place = *dev_ctx.eigen_device();
int64_t x_numel = x.numel();
auto stream = dev_ctx.stream();
......@@ -220,7 +220,8 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
// VectorizedRandomGenerator use curand_uniform4, so we only support
// vec_size is 4;
int vec_size = (phi::GetVectorizedSize<T>(x_data) == 4) ? 4 : 1;
auto gpu_config = GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size);
auto gpu_config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size);
auto offset =
((x_numel - 1) / (gpu_config.GetThreadNum() * vec_size) + 1) * vec_size;
......@@ -278,11 +279,13 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
}
template <typename T>
void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
const std::string dropout_implementation,
float dropout_prob, const Tensor& grad_y,
const Tensor& mask, int64_t size,
Tensor* grad_x, bool is_test = false) {
float dropout_prob,
const framework::Tensor& grad_y,
const framework::Tensor& mask, int64_t size,
framework::Tensor* grad_x,
bool is_test = false) {
using MT = typename details::MPTypeTrait<T>::Type;
auto stream = dev_ctx.stream();
MT factor;
......
......@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
inline void GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx,
const framework::Tensor* seed,
const bool is_fix_seed, const int seed_val,
const int offset, uint64_t* seed_data,
......
......@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dropout_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
......@@ -177,14 +177,3 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
ops::DropoutGradOpMaker<paddle::framework::OpDesc>,
ops::DropoutGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
REGISTER_OP_CPU_KERNEL(
dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
dropout_grad,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/dropout_impl.cu.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* seed =
context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
auto* y = context.Output<Tensor>("Out");
y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob");
auto& dropout_implementation =
context.Attr<std::string>("dropout_implementation");
bool upscale_in_train = (dropout_implementation == "upscale_in_train");
bool is_test = context.Attr<bool>("is_test");
auto& dev_ctx = context.cuda_device_context();
auto* mask = context.Output<Tensor>("Mask");
mask->mutable_data<uint8_t>(context.GetPlace());
bool is_fix_seed = context.Attr<bool>("fix_seed");
int seed_val = context.Attr<int>("seed");
DropoutFwGPUKernelDriver<T>(dev_ctx, is_test, dropout_implementation,
dropout_prob, upscale_in_train, is_fix_seed,
seed_val, *x, seed, mask, y);
}
};
template <typename DeviceContext, typename T>
class GPUDropoutGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
auto* mask = context.Input<Tensor>("Mask");
grad_x->mutable_data<T>(context.GetPlace());
auto size = grad_x->numel();
auto& dropout_implementation =
context.Attr<std::string>("dropout_implementation");
float dropout_prob = context.Attr<float>("dropout_prob");
bool is_test = context.Attr<bool>("is_test");
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
DropoutGradGPUKernelDriver<T>(dev_ctx, dropout_implementation, dropout_prob,
*grad_y, *mask, size, grad_x, is_test);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::bfloat16>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
dropout_grad, ops::GPUDropoutGradKernel<plat::CUDADeviceContext, float>,
ops::GPUDropoutGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::GPUDropoutGradKernel<plat::CUDADeviceContext, plat::bfloat16>,
ops::GPUDropoutGradKernel<plat::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstring>
#include <random>
#include <string>
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class CPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* seed =
context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
auto* y = context.Output<Tensor>("Out");
const auto* x_data = x->data<T>();
auto* y_data = y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob");
auto& dropout_implementation =
context.Attr<std::string>("dropout_implementation");
bool upscale_in_train = (dropout_implementation == "upscale_in_train");
if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
size_t size = phi::product(mask->dims());
// Special case when dropout_prob is 1.0
if (dropout_prob == 1.0f) {
std::memset(y_data, 0, size * sizeof(*y_data)); // NOLINT
std::memset(mask_data, 0, size * sizeof(*mask_data)); // NOLINT
return;
}
// std::minstd_rand engine;
// NOTE: fixed seed should only be used in unittest or for debug.
// Guarantee to use random seed in training.
int seed_data = 0;
if (seed) {
seed_data = *(seed->data<int>());
} else {
seed_data =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : 0;
}
auto engine = framework::GetCPURandomEngine(seed_data);
std::uniform_real_distribution<float> dist(0, 1);
for (size_t i = 0; i < size; ++i) {
if (dist(*engine) < dropout_prob) {
mask_data[i] = 0;
y_data[i] = 0;
} else {
mask_data[i] = 1;
if (upscale_in_train) {
y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
} else {
y_data[i] = x_data[i];
}
}
}
} else {
if (upscale_in_train) {
const auto* X_data = x->data<T>();
auto* Y_data = y->mutable_data<T>(context.GetPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < x->numel(); i++) {
Y_data[i] = X_data[i];
}
} else {
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
}
}
};
template <typename DeviceContext, typename T>
class DropoutGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
auto* mask = context.Input<Tensor>("Mask");
grad_x->mutable_data<T>(context.GetPlace());
auto dX = EigenVector<T>::Flatten(*grad_x);
auto dY = EigenVector<T>::Flatten(*grad_y);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto& dropout_implementation =
context.Attr<std::string>("dropout_implementation");
if (context.Attr<bool>("is_test") == true) {
if (dropout_implementation == "upscale_in_train") {
dX.device(place) = static_cast<T>(1) * dY;
} else {
float dropout_prob = context.Attr<float>("dropout_prob");
dX.device(place) = dY * static_cast<T>(1.0f - dropout_prob);
}
} else {
auto M = EigenVector<uint8_t>::Flatten(*mask);
if (dropout_implementation == "upscale_in_train") {
float dropout_prob = context.Attr<float>("dropout_prob");
if (dropout_prob == 1.0f) {
dX.device(place) = static_cast<T>(0) * dY;
} else {
dX.device(place) =
dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
}
} else {
dX.device(place) = dY * M.cast<T>();
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -15,8 +15,8 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/core/ddim.h"
......
......@@ -24,14 +24,13 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace f = paddle::framework;
namespace p = paddle::platform;
USE_OP(dropout);
USE_OP_ITSELF(dropout);
void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
// init
......
......@@ -8,15 +8,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dropout_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
namespace operators {
#ifdef PADDLE_WITH_XPU
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class DropoutXPUKernel : public framework::OpKernel<T> {
using XPUTyp = typename XPUTypeTrait<T>::Type;
......
......@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -140,9 +140,9 @@ class FMHARef {
if (dropout_param_.dropout_prob_) {
DropoutFwGPUKernelDriver<T>(
dev_ctx_, dropout_param_.is_test_,
static_cast<const std::string>(
dropout_param_.dropout_implementation_),
static_cast<const phi::GPUContext&>(dev_ctx_),
dropout_param_.is_test_, static_cast<const std::string>(
dropout_param_.dropout_implementation_),
dropout_param_.dropout_prob_, dropout_param_.is_upscale_in_train_,
dropout_param_.is_fix_seed_, dropout_param_.seed_val_,
static_cast<const Tensor&>(*softmax_out_tensor), dropout_param_.seed_,
......@@ -242,8 +242,9 @@ class FMHARef {
// dropout bw
if (dropout_param_.dropout_prob_) {
DropoutGradGPUKernelDriver<T>(
dev_ctx_, static_cast<const std::string>(
dropout_param_.dropout_implementation_),
static_cast<const phi::GPUContext&>(dev_ctx_),
static_cast<const std::string>(
dropout_param_.dropout_implementation_),
dropout_param_.dropout_prob_,
static_cast<const Tensor&>(*dropout_out_grad_tensor),
dropout_mask_out_tensor, softmax_out_grad_tensor->numel(),
......
......@@ -31,7 +31,7 @@ namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace memory = paddle::memory;
USE_OP(dropout);
USE_OP_ITSELF(dropout);
USE_OP(layer_norm);
template <typename T>
......
......@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -16,9 +16,9 @@ limitations under the License. */
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/unique_op.h"
......@@ -36,6 +36,14 @@ using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
using TensorList = std::vector<framework::Tensor>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
#define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR) \
inline bool is_##MODE_NAME(const framework::ExecutionContext& ctx) { \
const std::string& mode = ctx.Attr<std::string>("mode"); \
......
......@@ -22,7 +22,6 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dropout_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void DropoutGradRawKernel(const Context& dev_ctx,
const DenseTensor& mask,
const DenseTensor& out_grad,
float p,
bool is_test,
const std::string& mode,
DenseTensor* x_grad) {
auto* grad_x = x_grad;
auto* grad_y = &out_grad;
grad_x->mutable_data<T>(dev_ctx.GetPlace());
auto dX = EigenVector<T>::Flatten(*grad_x);
auto dY = EigenVector<T>::Flatten(*grad_y);
auto& place = *dev_ctx.eigen_device();
auto& dropout_implementation = mode;
if (is_test == true) {
if (dropout_implementation == "upscale_in_train") {
dX.device(place) = static_cast<T>(1) * dY;
} else {
dX.device(place) = dY * static_cast<T>(1.0f - p);
}
} else {
auto M = EigenVector<uint8_t>::Flatten(mask);
if (dropout_implementation == "upscale_in_train") {
if (p == 1.0f) {
dX.device(place) = static_cast<T>(0) * dY;
} else {
dX.device(place) = dY * M.cast<T>() / static_cast<T>(1.0f - p);
}
} else {
dX.device(place) = dY * M.cast<T>();
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(dropout_grad,
CPU,
ALL_LAYOUT,
phi::DropoutGradRawKernel,
float,
double,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dropout_kernel.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void DropoutRawKernel(const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> seed_tensor,
float p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
DenseTensor* out,
DenseTensor* mask) {
auto* y = out;
const auto* x_data = x.data<T>();
auto* y_data = y->mutable_data<T>(dev_ctx.GetPlace());
float dropout_prob = p;
auto& dropout_implementation = mode;
bool upscale_in_train = (dropout_implementation == "upscale_in_train");
if (!is_test) {
auto* mask_data = mask->mutable_data<uint8_t>(dev_ctx.GetPlace());
size_t size = phi::product(mask->dims());
// Special case when dropout_prob is 1.0
if (dropout_prob == 1.0f) {
std::memset(y_data, 0, size * sizeof(*y_data)); // NOLINT
std::memset(mask_data, 0, size * sizeof(*mask_data)); // NOLINT
return;
}
// std::minstd_rand engine;
// NOTE: fixed seed should only be used in unittest or for debug.
// Guarantee to use random seed in training.
int seed_data = 0;
if (seed_tensor.get_ptr() != nullptr) {
seed_data = *(seed_tensor->data<int>());
} else {
seed_data = fix_seed ? seed : 0;
}
auto engine = paddle::framework::GetCPURandomEngine(seed_data);
std::uniform_real_distribution<float> dist(0, 1);
for (size_t i = 0; i < size; ++i) {
if (dist(*engine) < dropout_prob) {
mask_data[i] = 0;
y_data[i] = 0;
} else {
mask_data[i] = 1;
if (upscale_in_train) {
y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
} else {
y_data[i] = x_data[i];
}
}
}
} else {
if (upscale_in_train) {
const auto* X_data = x.data<T>();
auto* Y_data = y->mutable_data<T>(dev_ctx.GetPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < x.numel(); i++) {
Y_data[i] = X_data[i];
}
} else {
auto X = EigenMatrix<T>::Reshape(x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto& place = *dev_ctx.eigen_device();
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(dropout,
CPU,
ALL_LAYOUT,
phi::DropoutRawKernel,
float,
double,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void DropoutGradRawKernel(const Context& dev_ctx,
const DenseTensor& mask,
const DenseTensor& out_grad,
float p,
bool is_test,
const std::string& mode,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void DropoutRawKernel(const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> seed_tensor,
float p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
DenseTensor* out,
DenseTensor* mask);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/dropout_impl.cu.h"
#include "paddle/phi/kernels/dropout_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void DropoutGradRawKernel(const Context& dev_ctx,
const DenseTensor& mask,
const DenseTensor& out_grad,
float p,
bool is_test,
const std::string& mode,
DenseTensor* x_grad) {
x_grad->mutable_data<T>(dev_ctx.GetPlace());
auto size = x_grad->numel();
paddle::operators::DropoutGradGPUKernelDriver<T>(
dev_ctx, mode, p, out_grad, mask, size, x_grad, is_test);
}
} // namespace phi
PD_REGISTER_KERNEL(dropout_grad,
GPU,
ALL_LAYOUT,
phi::DropoutGradRawKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/dropout_impl.cu.h"
#include "paddle/phi/kernels/dropout_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void DropoutRawKernel(const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> seed_tensor,
float p,
bool is_test,
const std::string& mode,
int seed,
bool fix_seed,
DenseTensor* out,
DenseTensor* mask) {
out->mutable_data<T>(dev_ctx.GetPlace());
float dropout_prob = p;
bool upscale_in_train = (mode == "upscale_in_train");
mask->mutable_data<uint8_t>(dev_ctx.GetPlace());
paddle::operators::DropoutFwGPUKernelDriver<T>(dev_ctx,
is_test,
mode,
dropout_prob,
upscale_in_train,
fix_seed,
seed,
x,
seed_tensor.get_ptr(),
mask,
out);
}
} // namespace phi
PD_REGISTER_KERNEL(dropout,
GPU,
ALL_LAYOUT,
phi::DropoutRawKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature DropoutOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"dropout",
{"X", "Seed"},
{"dropout_prob", "is_test", "dropout_implementation", "seed", "fix_seed"},
{"Out", "Mask"});
}
KernelSignature DropoutGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("dropout_grad",
{"Mask", GradVarName("Out")},
{"dropout_prob", "is_test", "dropout_implementation"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(dropout, phi::DropoutOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(dropout_grad, phi::DropoutGradOpArgumentMapping);
......@@ -933,5 +933,65 @@ class TestDropoutWithDeterminateSeedGenerator(unittest.TestCase):
self.check_static_result(place=place)
class TestDropoutBackward(unittest.TestCase):
def setUp(self):
np.random.seed(123)
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(fluid.CUDAPlace(0))
def cal_grad_upscale_train(self, mask, prob):
return mask.astype("float32") / (1 - prob)
def cal_grad_downscale_in_infer(self, mask):
return mask.astype("float32")
def test_backward_downscale_in_infer(self):
for place in self.places:
with fluid.dygraph.guard(place):
input = paddle.uniform([40, 40], dtype="float32")
input.stop_gradient = False
out, mask = core.ops.dropout(input, 'dropout_prob', 0.5)
out.backward()
self.assertTrue(
np.array_equal(input.gradient(
), self.cal_grad_downscale_in_infer(mask.numpy())))
def test_backward_upscale_train(self):
for place in self.places:
with fluid.dygraph.guard(place):
prob = 0.5
input = paddle.uniform([40, 40], dtype="float32")
input.stop_gradient = False
out, mask = core.ops.dropout(input, 'dropout_prob', prob,
"dropout_implementation",
"upscale_in_train")
out.backward()
self.assertTrue(
np.allclose(input.gradient(
), self.cal_grad_upscale_train(mask.numpy(), prob)))
def test_backward_upscale_train_2(self):
for place in self.places:
with fluid.dygraph.guard(place):
prob = 0.3
input = paddle.uniform([40, 40], dtype="float32")
input.stop_gradient = False
out, mask = core.ops.dropout(input, 'dropout_prob', prob,
"dropout_implementation",
"upscale_in_train")
out.backward()
self.assertTrue(
np.allclose(input.gradient(
), self.cal_grad_upscale_train(mask.numpy(), prob)))
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册