Unverified commit 946dbdae, authored by Qi Li and committed by GitHub

[ROCM] update fluid operators for rocm (part6), test=develop (#31301)

Parent 1cbccfa5
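The recurring pattern in this commit is conditional compilation: CUDA-only headers, enums, stream and error types are guarded with PADDLE_WITH_HIP / __HIPCC__ and mapped onto their MIOpen, HIP, or hipCUB counterparts, so the surrounding operator code stays backend-neutral (for example the GPUDNN_ACTIVATION_* macros and gpuStream_t / gpuError_t below). The following is only a minimal standalone sketch of that aliasing idea; the stub enums, the gpuActivationMode_t alias, and the demo main() are illustrative assumptions, not Paddle code.

// Standalone sketch of the CUDA/ROCm aliasing pattern used throughout this
// commit. Real builds include <cudnn.h> or the MIOpen headers; the stub enums
// below only make the example self-contained.
// Build: g++ -std=c++11 demo.cc                       (CUDA-style branch)
//        g++ -std=c++11 -DPADDLE_WITH_HIP demo.cc     (ROCm-style branch)
#include <cstdio>

enum cudnnActivationMode_t { CUDNN_ACTIVATION_RELU = 1 };   // stub for cuDNN
enum miopenActivationMode_t { miopenActivationRELU = 3 };   // stub for MIOpen

#ifdef PADDLE_WITH_HIP
// ROCm build: the shared GPUDNN_* name resolves to the MIOpen enum.
using gpuActivationMode_t = miopenActivationMode_t;
#define GPUDNN_ACTIVATION_RELU miopenActivationRELU
#else
// CUDA build: the shared GPUDNN_* name resolves to the cuDNN enum.
using gpuActivationMode_t = cudnnActivationMode_t;
#define GPUDNN_ACTIVATION_RELU CUDNN_ACTIVATION_RELU
#endif

int main() {
  // Call sites mention only the backend-neutral name, so functors such as
  // CudnnReluFunctor compile unchanged against either backend.
  gpuActivationMode_t mode = GPUDNN_ACTIVATION_RELU;
  std::printf("activation mode enum value: %d\n", static_cast<int>(mode));
  return 0;
}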
@@ -14,7 +14,11 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/activation_op.h"
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_desc.h"
+#else
 #include "paddle/fluid/platform/cudnn_desc.h"
+#endif
 namespace paddle {
 namespace operators {
...
@@ -14,7 +14,11 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/activation_op.h"
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_desc.h"
+#else
 #include "paddle/fluid/platform/cudnn_desc.h"
+#endif
 namespace paddle {
 namespace platform {
@@ -29,35 +33,71 @@ using platform::ActivationDescriptor;
 using platform::TensorDescriptor;
 using platform::CUDADeviceContext;
+#ifdef PADDLE_WITH_HIP
+#define GPUDNN_ACTIVATION_RELU miopenActivationRELU
+#define GPUDNN_ACTIVATION_CLIPPED_RELU miopenActivationCLIPPEDRELU
+#define GPUDNN_ACTIVATION_SIGMOID miopenActivationLOGISTIC
+#define GPUDNN_ACTIVATION_TANH miopenActivationTANH
+#else
+#define GPUDNN_ACTIVATION_RELU CUDNN_ACTIVATION_RELU
+#define GPUDNN_ACTIVATION_CLIPPED_RELU CUDNN_ACTIVATION_CLIPPED_RELU
+#define GPUDNN_ACTIVATION_SIGMOID CUDNN_ACTIVATION_SIGMOID
+#define GPUDNN_ACTIVATION_TANH CUDNN_ACTIVATION_TANH
+#endif
 template <typename T>
 struct CudnnActivationFunctor {
   using ELEMENT_TYPE = T;
+#ifdef PADDLE_WITH_HIP
+  CudnnActivationFunctor(const CUDADeviceContext& ctx, const T& c,
+                         const miopenActivationMode_t& m)
+      : ctx_(ctx), coef_(c), mode_(m) {}
+#else
   CudnnActivationFunctor(const CUDADeviceContext& ctx, const T& c,
                          const cudnnActivationMode_t& m)
       : ctx_(ctx), coef_(c), mode_(m) {}
+#endif
   void operator()(const Tensor& x, Tensor* out) {
     ActivationDescriptor act_desc;
     act_desc.set(mode_, coef_);
     TensorDescriptor x_desc, out_desc;
     x_desc.set(x);
     out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation"));
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationForward(
+        ctx_.cudnn_handle(), act_desc.desc(),
+        platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
+        platform::CudnnDataType<T>::kZero(), out_desc.desc(),
+        out->mutable_data<T>(ctx_.GetPlace())));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnActivationForward(
         ctx_.cudnn_handle(), act_desc.desc(),
         platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
         platform::CudnnDataType<T>::kZero(), out_desc.desc(),
        out->mutable_data<T>(ctx_.GetPlace())));
+#endif
   }
   const CUDADeviceContext& ctx_;
   const T coef_;
+#ifdef PADDLE_WITH_HIP
+  const miopenActivationMode_t mode_;
+#else
   const cudnnActivationMode_t mode_;
+#endif
 };
 template <typename T>
 struct CudnnActivationGradFunctor {
   using ELEMENT_TYPE = T;
+#ifdef PADDLE_WITH_HIP
+  CudnnActivationGradFunctor(const CUDADeviceContext& ctx, const T& c,
+                             const miopenActivationMode_t& m)
+      : ctx_(ctx), coef_(c), mode_(m) {}
+#else
   CudnnActivationGradFunctor(const CUDADeviceContext& ctx, const T& c,
                              const cudnnActivationMode_t& m)
       : ctx_(ctx), coef_(c), mode_(m) {}
+#endif
   void operator()(const Tensor& x, const Tensor& out, const Tensor dout,
                   Tensor* dx) {
     ActivationDescriptor act_desc;
@@ -67,27 +107,40 @@ struct CudnnActivationGradFunctor {
     out_desc.set(out);
     dout_desc.set(dout);
     dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad"));
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationBackward(
+        ctx_.cudnn_handle(), act_desc.desc(),
+        platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
+        dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
+        platform::CudnnDataType<T>::kZero(), dx_desc.desc(),
+        dx->mutable_data<T>(ctx_.GetPlace())));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnActivationBackward(
         ctx_.cudnn_handle(), act_desc.desc(),
         platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
         dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
         platform::CudnnDataType<T>::kZero(), dx_desc.desc(),
         dx->mutable_data<T>(ctx_.GetPlace())));
+#endif
   }
   const CUDADeviceContext& ctx_;
   const T coef_;
+#ifdef PADDLE_WITH_HIP
+  const miopenActivationMode_t mode_;
+#else
   const cudnnActivationMode_t mode_;
+#endif
 };
 template <typename T>
 struct CudnnReluFunctor : public CudnnActivationFunctor<T> {
   explicit CudnnReluFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
+      : CudnnActivationFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_RELU) {}
 };
 template <typename T>
 struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
+      : CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_RELU) {}
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -95,13 +148,13 @@ struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
 template <typename T>
 struct CudnnRelu6Functor : public CudnnActivationFunctor<T> {
   explicit CudnnRelu6Functor(const CUDADeviceContext& ctx)
-      : CudnnActivationFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {}
+      : CudnnActivationFunctor<T>(ctx, 6.0, GPUDNN_ACTIVATION_CLIPPED_RELU) {}
 };
 template <typename T>
 struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationGradFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {
-  }
+      : CudnnActivationGradFunctor<T>(ctx, 6.0,
+                                      GPUDNN_ACTIVATION_CLIPPED_RELU) {}
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -109,12 +162,12 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
 template <typename T>
 struct CudnnSigmoidFunctor : public CudnnActivationFunctor<T> {
   explicit CudnnSigmoidFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
+      : CudnnActivationFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_SIGMOID) {}
 };
 template <typename T>
 struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
+      : CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_SIGMOID) {}
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -122,12 +175,12 @@ struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
 template <typename T>
 struct CudnnTanhFunctor : public CudnnActivationFunctor<T> {
   explicit CudnnTanhFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
+      : CudnnActivationFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_TANH) {}
 };
 template <typename T>
 struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
-      : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
+      : CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_TANH) {}
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -183,6 +236,14 @@ namespace ops = paddle::operators;
   __macro(sigmoid, CudnnSigmoidFunctor, CudnnSigmoidGradFunctor); \
   __macro(tanh, CudnnTanhFunctor, CudnnTanhGradFunctor)
+#ifdef PADDLE_WITH_HIP
+#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
+  REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace,                    \
+                     ops::CudnnActivationKernel<ops::functor<float>>);    \
+  REGISTER_OP_KERNEL(                                                     \
+      act_type##_grad, CUDNN, plat::CUDAPlace,                            \
+      ops::CudnnActivationGradKernel<ops::grad_functor<float>>);
+#else
 #define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
   REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace,                    \
                      ops::CudnnActivationKernel<ops::functor<float>>,     \
@@ -191,5 +252,6 @@ namespace ops = paddle::operators;
       act_type##_grad, CUDNN, plat::CUDAPlace,                            \
       ops::CudnnActivationGradKernel<ops::grad_functor<float>>,           \
       ops::CudnnActivationGradKernel<ops::grad_functor<double>>);
+#endif
 FOR_EACH_CUDNN_OP_FUNCTOR(REGISTER_ACTIVATION_CUDNN_KERNEL);
@@ -24,9 +24,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/common_infer_shape_functions.h"
 #include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
 #include "paddle/fluid/platform/port.h"
-#ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/platform/cudnn_helper.h"
-#endif
 DECLARE_bool(use_mkldnn);
...
@@ -12,7 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
...
@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#ifndef PADDLE_WITH_HIP
+// HIP not support cudnnSpatialTfGridGeneratorForward
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
@@ -121,3 +124,5 @@ REGISTER_OP_KERNEL(affine_grid, CUDNN, plat::CUDAPlace,
 REGISTER_OP_KERNEL(affine_grid_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNAffineGridGradOpKernel<float>,
                    paddle::operators::CUDNNAffineGridGradOpKernel<double>);
+#endif  // not PADDLE_WITH_HIP
@@ -21,6 +21,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
 #endif
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_helper.h"
+#endif
 namespace paddle {
 namespace operators {
@@ -109,7 +112,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library{framework::LibraryType::kPlain};
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     if (platform::CanCUDNNBeUsed(ctx)) {
       library = framework::LibraryType::kCUDNN;
     }
@@ -226,7 +229,7 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
...
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include <cuda_runtime.h>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/allclose_op.h"
@@ -67,7 +66,11 @@ struct AllcloseFunctor<platform::CUDADeviceContext, T> {
     int block = 1024;
     int grid = (block - 1 + num) / block;
     grid = (grid > block) ? block : grid;
+#ifdef PADDLE_WITH_HIP
+    hipMemset(out_data, true, sizeof(bool));
+#else
     cudaMemset(out_data, true, sizeof(bool));
+#endif
     AllcloseCUDAKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
         in_data, other_data, rtol, atol, equal_nan, num, out_data);
   }
...
@@ -14,9 +14,15 @@ limitations under the License. */
 #pragma once
-#ifdef __NVCC__
-#include <cub/cub.cuh>
+#if defined(__NVCC__) || defined(__HIPCC__)
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include <limits>
 #include <string>
 #include <typeinfo>
...
@@ -16,13 +16,28 @@ limitations under the License. */
 #include <thrust/execution_policy.h>
 #include <thrust/sequence.h>
 #include <thrust/sort.h>
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/argsort_op.h"
 #include "paddle/fluid/operators/transpose_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#ifdef __HIPCC__
+namespace rocprim {
+namespace detail {
+template <>
+struct radix_key_codec_base<paddle::platform::float16>
+    : radix_key_codec_integral<paddle::platform::float16, uint16_t> {};
+}  // namespace detail
+}  // namespace rocprim
+#else
 // set cub base traits in order to handle float16
 namespace cub {
 template <>
@@ -30,6 +45,7 @@ struct NumericTraits<paddle::platform::float16>
     : BaseTraits<FLOATING_POINT, true, false, uint16_t,
                  paddle::platform::float16> {};
 }  // namespace cub
+#endif
 namespace paddle {
 namespace operators {
@@ -139,7 +155,7 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
       cub::CountingInputIterator<IndType>>
       segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
-  cudaError_t err;
+  gpuError_t err;
   if (descending) {
     err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
         nullptr, temp_storage_bytes, inp, sorted_out_ptr,
...
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <cublas.h>
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/batch_fc_op.h"
@@ -42,7 +41,7 @@ __global__ void add_bias_kernel(T* data, int slot_pairs_num, int ins_num,
 }
 template <typename T>
-void add_bias(cudaStream_t stream, T* data, int slot_pairs_num, int ins_num,
+void add_bias(gpuStream_t stream, T* data, int slot_pairs_num, int ins_num,
               int out_dim, const T* bias) {
   add_bias_kernel<<<GET_BLOCKS(slot_pairs_num * ins_num * out_dim),
                     CUDA_NUM_THREADS, 0, stream>>>(data, slot_pairs_num,
@@ -65,7 +64,7 @@ __global__ void add_bias_grad_kernel(const T* dout_data, int slot_pairs_num,
 }
 template <typename T>
-void add_bias_grad(cudaStream_t stream, const T* dout_data, int slot_pairs_num,
+void add_bias_grad(gpuStream_t stream, const T* dout_data, int slot_pairs_num,
                    int ins_num, int out_dim, T* db_data) {
   add_bias_grad_kernel<<<GET_BLOCKS(slot_pairs_num * out_dim), CUDA_NUM_THREADS,
                          0, stream>>>(dout_data, slot_pairs_num, ins_num,
...
@@ -16,12 +16,17 @@ limitations under the License. */
 #include <cfloat>
 #include <string>
 #include <vector>
+#ifdef __NVCC__
 #include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/operators/batch_norm_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/norm_utils.cu.h"
-#include "paddle/fluid/platform/cudnn_helper.h"
 #include "paddle/fluid/platform/float16.h"
 DECLARE_bool(cudnn_batchnorm_spatial_persistent);
@@ -73,6 +78,11 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
     auto dtype = platform::CudnnDataType<T>::type;
+#ifdef PADDLE_WITH_HIP
+    // HIP do not support compute format of NHWC
+    auto compute_format = DataLayout::kNCHW;
+#else
     const bool fast_nhwc_batch_norm =
         test_mode ||
         (dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent);
@@ -81,6 +91,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
         fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
             ? DataLayout::kNHWC
             : DataLayout::kNCHW;
+#endif
     Tensor transformed_x(x->type());
     Tensor transformed_y(y->type());
@@ -98,7 +109,17 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
       transformed_y.ShareDataWith(*y);
     }
     // ------------------- cudnn descriptors ---------------------
+#ifdef PADDLE_WITH_HIP
+    miopenTensorDescriptor_t data_desc_;
+    miopenTensorDescriptor_t bn_param_desc_;
+    miopenBatchNormMode_t mode_;
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
+#else
     cudnnTensorDescriptor_t data_desc_;
     cudnnTensorDescriptor_t bn_param_desc_;
     cudnnBatchNormMode_t mode_;
@@ -107,6 +128,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
         platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
+#endif
     if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
       LOG(ERROR) << "Provided epsilon is smaller than "
@@ -114,7 +136,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
                  << "CUDNN_BN_MIN_EPSILON instead.";
     }
     epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
-#if CUDNN_VERSION_MIN(7, 0, 1)
+#ifdef PADDLE_WITH_HIP
+    mode_ = miopenBNSpatial;
+#elif CUDNN_VERSION_MIN(7, 0, 1)
     if (FLAGS_cudnn_batchnorm_spatial_persistent) {
       mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
     } else {
@@ -134,6 +159,17 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
       dims = {N, C, H, W, D};
       strides = {H * W * D * C, 1, W * D * C, D * C, C};
     }
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
+        data_desc_, CudnnDataType<T>::type,
+        x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
+        const_cast<int *>(strides.data())));
+    // Note: PERSISTENT not implemented for inference
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenDeriveBNTensorDescriptor(
+            bn_param_desc_, data_desc_, test_mode ? miopenBNSpatial : mode_));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
@@ -142,6 +178,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
         platform::dynload::cudnnDeriveBNTensorDescriptor(
             bn_param_desc_, data_desc_,
             test_mode ? CUDNN_BATCHNORM_SPATIAL : mode_));
+#endif
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
@@ -188,6 +225,30 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
               "variance is [%d], the dimensions of variance is [%s].",
               C, est_var->dims()[0], est_var->dims()));
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::miopenBatchNormalizationForwardInference(
+              handle, miopenBNSpatial,
+              const_cast<void *>(
+                  static_cast<const void *>(CudnnDataType<T>::kOne())),
+              const_cast<void *>(
+                  static_cast<const void *>(CudnnDataType<T>::kZero())),
+              data_desc_,
+              static_cast<const void *>(transformed_x.template data<T>()),
+              data_desc_,
+              static_cast<void *>(
+                  transformed_y.template mutable_data<T>(ctx.GetPlace())),
+              bn_param_desc_,
+              const_cast<void *>(static_cast<const void *>(
+                  scale->template data<BatchNormParamType<T>>())),
+              const_cast<void *>(static_cast<const void *>(
+                  bias->template data<BatchNormParamType<T>>())),
+              const_cast<void *>(static_cast<const void *>(
+                  est_mean->template data<BatchNormParamType<T>>())),
+              const_cast<void *>(static_cast<const void *>(
+                  est_var->template data<BatchNormParamType<T>>())),
+              epsilon));
+#else
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnBatchNormalizationForwardInference(
               handle,
@@ -200,6 +261,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
              bias->template data<BatchNormParamType<T>>(),
              est_mean->template data<BatchNormParamType<T>>(),
              est_var->template data<BatchNormParamType<T>>(), epsilon));
+#endif
     } else {
       // if MomentumTensor is set, use MomentumTensor value, momentum
       // is only used in this training branch
@@ -302,6 +364,36 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
              reserve_space_size));
 #endif  // CUDNN_VERSION_MIN(7, 4, 1)
       if (!called) {
+#ifdef PADDLE_WITH_HIP
+        PADDLE_ENFORCE_CUDA_SUCCESS(
+            platform::dynload::miopenBatchNormalizationForwardTraining(
+                handle, mode_, const_cast<void *>(static_cast<const void *>(
+                                   CudnnDataType<T>::kOne())),
+                const_cast<void *>(
+                    static_cast<const void *>(CudnnDataType<T>::kZero())),
+                data_desc_,
+                static_cast<const void *>(transformed_x.template data<T>()),
+                data_desc_,
+                static_cast<void *>(
+                    transformed_y.template mutable_data<T>(ctx.GetPlace())),
+                bn_param_desc_,
+                const_cast<void *>(static_cast<const void *>(
+                    scale->template data<BatchNormParamType<T>>())),
+                const_cast<void *>(static_cast<const void *>(
+                    bias->template data<BatchNormParamType<T>>())),
+                this_factor,
+                static_cast<void *>(
+                    mean_out->template mutable_data<BatchNormParamType<T>>(
+                        ctx.GetPlace())),
+                static_cast<void *>(variance_out->template mutable_data<
+                                    BatchNormParamType<T>>(ctx.GetPlace())),
+                epsilon,
+                static_cast<void *>(
+                    saved_mean->template mutable_data<BatchNormParamType<T>>(
+                        ctx.GetPlace())),
+                static_cast<void *>(saved_variance->template mutable_data<
+                                    BatchNormParamType<T>>(ctx.GetPlace()))));
+#else
         PADDLE_ENFORCE_CUDA_SUCCESS(
             platform::dynload::cudnnBatchNormalizationForwardTraining(
                 handle, mode_, CudnnDataType<T>::kOne(),
@@ -319,6 +411,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
                    ctx.GetPlace()),
                saved_variance->template mutable_data<BatchNormParamType<T>>(
                    ctx.GetPlace())));
+#endif
       }
     }
   }
@@ -329,11 +422,19 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
      TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
          ctx, &transformed_y, y);
    }
+#ifdef PADDLE_WITH_HIP
+    // clean when exit.
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
+#else
     // clean when exit.
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
+#endif
   }
 };
@@ -416,7 +517,7 @@ class InplaceHelper {
                   const BatchNormParamType<T> *mean,
                   const BatchNormParamType<T> *variance, double epsilon, int C,
                   int M, const int num, const T *y, int grid2, const int block,
-                  const cudaStream_t &stream) {
+                  const gpuStream_t &stream) {
     PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument(
                                 "X and Y should be inplaced in inplace mode"));
     KeBNRestoreData<<<grid2, block, 0, stream>>>(
@@ -566,6 +667,10 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     auto dtype = platform::CudnnDataType<T>::type;
     const auto *reserve_space = ctx.Input<Tensor>("ReserveSpace");
+#ifdef PADDLE_WITH_HIP
+    // HIP do not support compute format of NHWC
+    auto compute_format = DataLayout::kNCHW;
+#else
     const bool fast_nhwc_batch_norm =
         dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent &&
         reserve_space != nullptr;
@@ -573,6 +678,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
         fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC
             ? DataLayout::kNHWC
             : DataLayout::kNCHW;
+#endif
     Tensor transformed_x(x->type());
     Tensor transformed_d_y(d_y->type());
@@ -626,7 +732,17 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
       return;
     }
     // ------------------- cudnn descriptors ---------------------
+#ifdef PADDLE_WITH_HIP
+    miopenTensorDescriptor_t data_desc_;
+    miopenTensorDescriptor_t bn_param_desc_;
+    miopenBatchNormMode_t mode_;
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenCreateTensorDescriptor(&data_desc_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenCreateTensorDescriptor(&bn_param_desc_));
+#else
     cudnnTensorDescriptor_t data_desc_;
     cudnnTensorDescriptor_t bn_param_desc_;
     cudnnBatchNormMode_t mode_;
@@ -635,13 +751,16 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
         platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
+#endif
     if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
       LOG(ERROR) << "Provided epsilon is smaller than "
                  << "CUDNN_BN_MIN_EPSILON. Setting it to "
                  << "CUDNN_BN_MIN_EPSILON instead.";
     }
     epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
-#if CUDNN_VERSION_MIN(7, 0, 1)
+#ifdef PADDLE_WITH_HIP
+    mode_ = miopenBNSpatial;
+#elif CUDNN_VERSION_MIN(7, 0, 1)
     if (FLAGS_cudnn_batchnorm_spatial_persistent) {
       mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
     } else {
@@ -651,12 +770,22 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     mode_ = CUDNN_BATCHNORM_SPATIAL;
 #endif  // CUDNN_VERSION_MIN(7, 0, 1)
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSetTensorDescriptor(
+        data_desc_, CudnnDataType<T>::type,
+        x_dims.size() > 3 ? x_dims.size() : 4, const_cast<int *>(dims.data()),
+        const_cast<int *>(strides.data())));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenDeriveBNTensorDescriptor(bn_param_desc_,
+                                                          data_desc_, mode_));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
                                                          data_desc_, mode_));
+#endif
     const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
     const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
@@ -741,6 +870,22 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
               /*reserveSpaceSizeInBytes=*/reserve_space_size));
 #endif  // CUDNN_VERSION_MIN(7, 4, 1)
       if (!called) {
+#ifdef PADDLE_WITH_HIP
+        PADDLE_ENFORCE_CUDA_SUCCESS(
+            platform::dynload::miopenBatchNormalizationBackward(
+                dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
+                CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
+                CudnnDataType<T>::kZero(), data_desc_,
+                transformed_x.template data<T>(), data_desc_,
+                transformed_d_y.template data<T>(), data_desc_,
+                transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
+                bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
+                d_scale->template mutable_data<BatchNormParamType<T>>(
+                    ctx.GetPlace()),
+                d_bias->template mutable_data<BatchNormParamType<T>>(
+                    ctx.GetPlace()),
+                epsilon, saved_mean_data, saved_var_data));
+#else
         PADDLE_ENFORCE_CUDA_SUCCESS(
             platform::dynload::cudnnBatchNormalizationBackward(
                 dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
@@ -755,6 +900,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
                 d_bias->template mutable_data<BatchNormParamType<T>>(
                     ctx.GetPlace()),
                 epsilon, saved_mean_data, saved_var_data));
+#endif
       }
       if (data_layout == DataLayout::kNHWC &&
@@ -784,11 +930,19 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
        }
      }
+#ifdef PADDLE_WITH_HIP
+      // clean when exit.
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::miopenDestroyTensorDescriptor(data_desc_));
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::miopenDestroyTensorDescriptor(bn_param_desc_));
+#else
       // clean when exit.
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
       PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
+#endif
     } else {
       const auto *running_mean = ctx.Input<Tensor>("Mean");
       const auto *running_var = ctx.Input<Tensor>("Variance");
@@ -886,6 +1040,18 @@ class BatchNormDoubleGradKernel<platform::CUDADeviceContext, T>
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
+#ifdef PADDLE_WITH_HIP
+// MIOPEN do not support double
+REGISTER_OP_CUDA_KERNEL(
+    batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    batch_norm_grad_grad,
+    ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>);
+#else
 REGISTER_OP_CUDA_KERNEL(
     batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
     ops::BatchNormKernel<plat::CUDADeviceContext, double>,
@@ -898,3 +1064,4 @@ REGISTER_OP_CUDA_KERNEL(
     batch_norm_grad_grad,
     ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, float>,
     ops::BatchNormDoubleGradKernel<plat::CUDADeviceContext, double>);
+#endif
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include <algorithm>
-#include "cub/cub.cuh"
 #include "paddle/fluid/operators/bce_loss_op.h"
 #include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
...
@@ -105,7 +105,7 @@ TEST(Seq2BatchPadding, CPU) {
                                128);
 }
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 TEST(SequencePadding, CUDA) {
   auto place = paddle::platform::CUDAPlace(0);
   auto *context = static_cast<paddle::platform::CUDADeviceContext *>(
...
@@ -123,7 +123,7 @@ TEST(SequencePoolingGrad, CPU_SUM) {
                                                  lod2, 128);
 }
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 TEST(SequencePoolingGrad, CUDA_SUM) {
   auto place = paddle::platform::CUDAPlace(0);
   auto *context = static_cast<paddle::platform::CUDADeviceContext *>(
...
@@ -44,10 +44,18 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
     framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
     T* seq_data = seq->mutable_data<T>(context.GetPlace());
+#ifdef PADDLE_WITH_HIP
+    hipLaunchKernelGGL(
+        HIP_KERNEL_NAME(SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS>),
+        dim3(num_seq), dim3(PADDLE_CUDA_NUM_THREADS), 0, context.stream(),
+        seq_data, abs_offset_lod[level].CUDAMutableData(context.GetPlace()),
+        scales, seq_width);
+#else
     SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<<
         num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>(
         seq_data, abs_offset_lod[level].CUDAMutableData(context.GetPlace()),
         scales, seq_width);
+#endif
   }
 };
...
@@ -16,7 +16,11 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_helper.h"
+#else
 #include "paddle/fluid/platform/cudnn_helper.h"
+#endif
 namespace paddle {
 namespace operators {
@@ -45,6 +49,16 @@ void SoftmaxCUDNNFunctor<T>::operator()(
   if (cudnn_tensor_dims.size() <= 2) {
     cudnn_tensor_dims.resize(4, 1);
   }
+#ifdef PADDLE_WITH_HIP
+  miopenTensorDescriptor_t cudnn_x_desc =
+      xDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  miopenTensorDescriptor_t cudnn_y_desc =
+      xDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward(
+      context.cudnn_handle(), CudnnDataType<T>::kOne(), cudnn_x_desc,
+      X->data<T>(), CudnnDataType<T>::kZero(), cudnn_y_desc,
+      Y->mutable_data<T>(context.GetPlace())));
+#else
   cudnnTensorDescriptor_t cudnn_x_desc =
       xDesc.descriptor<T>(layout, cudnn_tensor_dims);
   cudnnTensorDescriptor_t cudnn_y_desc =
@@ -54,6 +68,7 @@ void SoftmaxCUDNNFunctor<T>::operator()(
       CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_x_desc,
       X->data<T>(), CudnnDataType<T>::kZero(), cudnn_y_desc,
       Y->mutable_data<T>(context.GetPlace())));
+#endif
 }
 template <typename T>
@@ -74,6 +89,19 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
   if (cudnn_tensor_dims.size() <= 2) {
     cudnn_tensor_dims.resize(4, 1);
   }
+#ifdef PADDLE_WITH_HIP
+  miopenTensorDescriptor_t cudnn_y_desc =
+      yDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  miopenTensorDescriptor_t cudnn_xgrad_desc =
+      dxDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  miopenTensorDescriptor_t cudnn_ygrad_desc =
+      dyDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward(
+      context.cudnn_handle(), CudnnDataType<T>::kOne(), cudnn_y_desc,
+      Y->data<T>(), cudnn_ygrad_desc, YGrad->data<T>(),
+      CudnnDataType<T>::kZero(), cudnn_xgrad_desc,
+      XGrad->mutable_data<T>(context.GetPlace())));
+#else
   cudnnTensorDescriptor_t cudnn_y_desc =
       yDesc.descriptor<T>(layout, cudnn_tensor_dims);
   cudnnTensorDescriptor_t cudnn_xgrad_desc =
@@ -86,15 +114,20 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
       Y->data<T>(), cudnn_ygrad_desc, YGrad->data<T>(),
       CudnnDataType<T>::kZero(), cudnn_xgrad_desc,
       XGrad->mutable_data<T>(context.GetPlace())));
+#endif
 }
-template class SoftmaxCUDNNFunctor<platform::float16>;
 template class SoftmaxCUDNNFunctor<float>;
-template class SoftmaxCUDNNFunctor<double>;
+template class SoftmaxCUDNNFunctor<platform::float16>;
 template class SoftmaxGradCUDNNFunctor<float>;
-template class SoftmaxGradCUDNNFunctor<double>;
 template class SoftmaxGradCUDNNFunctor<platform::float16>;
+// MIOPEN do not support double
+#ifndef PADDLE_WITH_HIP
+template class SoftmaxCUDNNFunctor<double>;
+template class SoftmaxGradCUDNNFunctor<double>;
+#endif
 template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
                               false>;
 template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
...
@@ -35,7 +35,7 @@ class SoftmaxGradFunctor {
                   framework::Tensor* x_grad);
 };
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 template <typename T>
 class SoftmaxCUDNNFunctor {
  public:
...
@@ -22,7 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/pooling.h"
-#ifdef __NVCC__
+#if defined(__HIPCC__) || defined(__NVCC__)
 #include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
 #endif
...
@@ -278,6 +278,9 @@ class OpTest(unittest.TestCase):
         def is_mkldnn_op_test():
             return hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True
+        def is_rocm_op_test():
+            return core.is_compiled_with_rocm()
         if not hasattr(cls, "op_type"):
             raise AssertionError(
                 "This test do not have op_type in class attrs, "
@@ -298,7 +301,8 @@ class OpTest(unittest.TestCase):
                 and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \
                 and not hasattr(cls, 'exist_fp64_check_grad') \
                 and not is_xpu_op_test() \
-                and not is_mkldnn_op_test():
+                and not is_mkldnn_op_test() \
+                and not is_rocm_op_test():
                 raise AssertionError(
                     "This test of %s op needs check_grad with fp64 precision." %
                     cls.op_type)
...