“9e11ca8096a57cda6d91741c064b362180ff2a50”上不存在“paddle/legacy/utils/arch/osx/Locks.cpp”
未验证 提交 84639b61 编写于 作者: Q Qi Li 提交者: GitHub

[ROCM] update fluid operators for rocm (part3), test=develop (#31213)

* [ROCM] update fluid operators for rocm (part3), test=develop

* fix clang format error, test=develop
上级 3b9db171
......@@ -24,22 +24,28 @@ file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_gru);\n")
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_lstm);\n")
if (WITH_GPU)
if (WITH_GPU OR WITH_ROCM)
# fused_bn_activation_op needs cudnn 7.4.1 above
if (NOT ${CUDNN_VERSION} VERSION_LESS 7401)
# HIP not support bn act fuse in MIOPEN
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7401))
op_library(fused_bn_activation_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_batch_norm_act);\n")
endif()
# conv_fusion_op needs cudnn 7 above
if (NOT ${CUDNN_VERSION} VERSION_LESS 7100)
# HIP not support cudnnConvolutionBiasActivationForward
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7100))
op_library(conv_fusion_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
endif()
# fusion_transpose_flatten_concat_op
op_library(fusion_transpose_flatten_concat_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_transpose_flatten_concat);\n")
# HIP not support cudnnTransformTensor
if(NOT WITH_ROCM)
op_library(fusion_transpose_flatten_concat_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_transpose_flatten_concat);\n")
endif()
# fusion_conv_inception_op needs cudnn 7 above
if (NOT ${CUDNN_VERSION} VERSION_LESS 7100)
# HIP not support cudnnConvolutionBiasActivationForward
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7100))
op_library(fusion_conv_inception_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_inception_fusion);\n")
endif()
......@@ -60,8 +66,9 @@ if (WITH_GPU)
cc_test(test_fusion_group_op SRCS fusion_group_op_test.cc DEPS fusion_group_op)
endif()
# fused_bn_add_activation
if (NOT ${CUDNN_VERSION} VERSION_LESS 7401)
op_library(fused_bn_add_activation_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_bn_add_activation);\n")
# HIP not support bn act fuse in MIOPEN
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7401))
op_library(fused_bn_add_activation_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_bn_add_activation);\n")
endif()
endif()
......@@ -12,10 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include <cub/cub.cuh> // NOLINT
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
......@@ -39,7 +37,11 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
in_embs_(framework::proto::VarType::INT64);
framework::DDim in_dim{input_num};
int device_id;
#ifdef PADDLE_WITH_HIP
hipGetDevice(&device_id);
#else
cudaGetDevice(&device_id);
#endif
in_ids_.Resize(in_dim);
in_embs_.Resize(in_dim);
int64_t *in_ids_d =
......@@ -52,11 +54,17 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
in1s.push_back(reinterpret_cast<uintptr_t>(ids[i]->data<int64_t>()));
in2s.push_back(reinterpret_cast<uintptr_t>(embs[i]->data<T>()));
}
#ifdef PADDLE_WITH_HIP
hipMemcpyAsync(in_ids_d, in1s.data(), sizeof(int64_t) * input_num,
hipMemcpyHostToDevice, device_ctx.stream());
hipMemcpyAsync(in_embs_d, in2s.data(), sizeof(int64_t) * input_num,
hipMemcpyHostToDevice, device_ctx.stream());
#else
cudaMemcpyAsync(in_ids_d, in1s.data(), sizeof(int64_t) * input_num,
cudaMemcpyHostToDevice, device_ctx.stream());
cudaMemcpyAsync(in_embs_d, in2s.data(), sizeof(int64_t) * input_num,
cudaMemcpyHostToDevice, device_ctx.stream());
#endif
auto *bias = context.Input<framework::Tensor>("Bias");
auto *scale = context.Input<framework::Tensor>("Scale");
......
......@@ -12,7 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cuda_device_function.h"
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
......@@ -89,7 +88,7 @@ __global__ void TransposeQkvKernel(const int H, const T *input, const T *bias,
void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
const int head_num, const float *input, const float *bias,
float *output, cudaStream_t stream) {
float *output, gpuStream_t stream) {
// BxSx3xNxH + 3xNxH -> 3xBxNxSxH
int scratch_size = batch * head_num * seq_len * seq_len;
const dim3 grid(seq_len, batch, 3);
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
......
......@@ -83,7 +83,7 @@ class LiteEngineOp : public framework::OperatorBase {
<< engine_->GetInputNames()[i] << ")";
inference::lite::utils::TensorCopy(&dst_t, &src_t, *ctx, zero_copy_);
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(dev_place)) {
platform::GpuStreamSync(
static_cast<const platform::CUDADeviceContext *>(ctx)->stream());
......@@ -101,7 +101,7 @@ class LiteEngineOp : public framework::OperatorBase {
<< engine_->GetOutputNames()[i] << ")";
inference::lite::utils::TensorCopy(dst_t, &src_t, *ctx, zero_copy_);
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(dev_place)) {
platform::GpuStreamSync(
static_cast<const platform::CUDADeviceContext *>(ctx)->stream());
......
......@@ -67,7 +67,7 @@ TEST(LiteEngineOp, engine_op) {
*block_->add_ops() = *elt_add->Proto();
*block_->add_ops() = *fetch->Proto();
framework::Scope scope;
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
#else
......@@ -84,11 +84,11 @@ TEST(LiteEngineOp, engine_op) {
std::vector<std::string> repetitive_params{"x", "y"};
inference::lite::EngineConfig config;
config.valid_places = {
#ifdef PADDLE_WITH_CUDA
paddle::lite_api::Place({TARGET(kCUDA), PRECISION(kFloat)}),
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
paddle::lite_api::Place({TARGET(kCUDA), PRECISION(kFloat)}),
#endif
paddle::lite_api::Place({TARGET(kX86), PRECISION(kFloat)}),
paddle::lite_api::Place({TARGET(kHost), PRECISION(kAny)}),
paddle::lite_api::Place({TARGET(kX86), PRECISION(kFloat)}),
paddle::lite_api::Place({TARGET(kHost), PRECISION(kAny)}),
};
serialize_params(&(config.param), &scope, repetitive_params);
config.model = program.Proto()->SerializeAsString();
......
......@@ -55,7 +55,7 @@ void AddFetchListToBlockDesc(framework::proto::BlockDesc* block,
void serialize_params(std::string* str, framework::Scope* scope,
const std::vector<std::string>& params) {
std::ostringstream os;
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
#else
......@@ -106,7 +106,7 @@ void CreateTensor(framework::Scope* scope, const std::string& name,
tensor->Resize(dims);
platform::Place place;
if (in_cuda) {
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
place = platform::CUDAPlace(0);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
......
......@@ -41,7 +41,7 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
template <typename T>
HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) {
#ifdef __CUDA_ARCH__
#if defined(__CUDA_ARCH__) || defined(__HIPCC__) // @{ Group LowerBound
// The following code is from
// https://en.cppreference.com/w/cpp/algorithm/lower_bound
auto *first = x;
......@@ -59,12 +59,12 @@ HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) {
return static_cast<size_t>(first - x);
#else
return static_cast<size_t>(std::lower_bound(x, x + num, val) - x);
#endif
#endif // @} End Group LowerBound
}
template <typename T>
HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) {
#ifdef __CUDA_ARCH__
#if defined(__CUDA_ARCH__) || defined(__HIPCC__) // @{ Group UpperBound
// The following code is from
// https://en.cppreference.com/w/cpp/algorithm/upper_bound
auto *first = x;
......@@ -82,7 +82,7 @@ HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) {
return static_cast<size_t>(first - x);
#else
return static_cast<size_t>(std::upper_bound(x, x + num, val) - x);
#endif
#endif // @} End Group UpperBound
}
} // namespace math
......
......@@ -134,7 +134,7 @@ TEST(BeamSearch, CPU) {
paddle::platform::CPUPlace>();
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(BeamSearch, GPU) {
TestBeamSearch<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace>();
......
......@@ -102,7 +102,7 @@ class Blas {
T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C,
int ldc) const;
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML // @{ Group MKLML: class Blas
template <typename T>
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
const int K) const;
......@@ -126,7 +126,7 @@ class Blas {
const int* indx, const int* pntrb, const int* pntre, const T* b,
const int* ldb, const T* beta, T* c, const int* ldc) const;
#if !defined(PADDLE_WITH_CUDA)
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
template <typename T>
void MatMulWithHead(const framework::Tensor& mat_a,
const MatDescriptor& dim_a,
......@@ -135,7 +135,7 @@ class Blas {
framework::Tensor* mat_out, T beta,
bool mat_y_split_vertical) const;
#endif
#endif
#endif // @} End Group MKLML: class Blas
template <typename T>
void MatMul(const int M, const int N, const int K, const T* A, const T* B,
......@@ -210,7 +210,8 @@ class Blas {
int K, T alpha, const T** A, const T** B, T beta, T** C,
int batchCount) const;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP)
template <typename T>
void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
int W1, int H1, int W2, int H2, T alpha, const T* A,
......@@ -235,7 +236,7 @@ class Blas {
CBLAS_DIAG diag, int M, int N, T alpha, const T* A, int lda, T* B,
int ldb) const;
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T>
void BatchedGETRF(int n, T** a, int* ipiv, int* info, int batch_size) const;
......@@ -262,7 +263,7 @@ class BlasT : private Blas<DeviceContext> {
Base()->template GEMM<T>(args...);
}
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML // @{ Group MKLML: class BlasT
template <typename... ARGS>
T* GEMM_ALLOC(ARGS... args) const {
return Base()->template GEMM_ALLOC<T>(args...);
......@@ -288,13 +289,13 @@ class BlasT : private Blas<DeviceContext> {
Base()->template CSRMM<T>(args...);
}
#if !defined(PADDLE_WITH_CUDA)
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
template <typename... ARGS>
void MatMulWithHead(ARGS... args) const {
Base()->template MatMulWithHead<T>(args...);
}
#endif
#endif
#endif // @} End Group MKLML: class BlasT
template <typename... ARGS>
void MatMul(ARGS... args) const {
......@@ -386,7 +387,7 @@ class BlasT : private Blas<DeviceContext> {
Base()->template TRSM<T>(args...);
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename... ARGS>
void BatchedGETRF(ARGS... args) const {
Base()->template BatchedGETRF<T>(args...);
......@@ -429,3 +430,6 @@ inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/blas_impl.cu.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/operators/math/blas_impl.hip.h"
#endif
......@@ -1046,7 +1046,8 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
#endif
}
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP) // @{ Group Blas MKLML: BatchedGEMMWithHead
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
......@@ -1116,7 +1117,7 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
}
}
}
#endif
#endif // @} End Group Blas MKLML: BatchedGEMMWithHead
template <typename DeviceContext>
template <typename T>
......@@ -1192,7 +1193,9 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
}
}
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP)
// @{ Group Blas MKLML: MatMulWithHead
/*
* Multiple two matrixes with multiple heads
*
......@@ -1319,7 +1322,7 @@ void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
dim_a.stride_, dim_b.stride_, head_number, mat_b_split_vertical);
}
}
#endif
#endif // @} End Group Blas MKLML: MatMulWithHead
template <typename DeviceContext>
template <typename T>
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/rocblas.h"
#include "paddle/fluid/platform/gpu_info.h"
DECLARE_bool(enable_cublas_tensor_op_math);
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct CUBlas;
template <>
struct CUBlas<float> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_sgemm(args...));
}
template <typename... ARGS>
static void AXPY(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_saxpy(args...));
}
template <typename... ARGS>
static void SCAL(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_sscal(args...));
}
template <typename... ARGS>
static void VCOPY(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_scopy(args...));
}
template <typename... ARGS>
static void GEMV(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_sgemv(args...));
}
template <typename... ARGS>
static void GEMM_STRIDED_BATCH(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::rocblas_sgemm_strided_batched(args...));
}
// HIP not supportted, refer to the doc here:
// https://github.com/ROCm-Developer-Tools/HIP/blob/roc-3.5.x/docs/markdown/CUBLAS_API_supported_by_HIP.md
template <typename... ARGS>
static void GEMM_EX(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasSgemmEx is not supported on HIP platform."));
}
template <typename... ARGS>
static void TRSM(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_strsm(args...));
}
template <typename... ARGS>
static void GETRF_BATCH(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasSgetrfBatched is not supported on HIP platform."));
}
template <typename... ARGS>
static void GETRI_BATCH(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasSgetriBatched is not supported on HIP platform."));
}
template <typename... ARGS>
static void MATINV_BATCH(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasSmatinvBatched is not supported on HIP platform."));
}
};
template <>
struct CUBlas<double> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_dgemm(args...));
}
template <typename... ARGS>
static void AXPY(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_daxpy(args...));
}
template <typename... ARGS>
static void SCAL(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_dscal(args...));
}
template <typename... ARGS>
static void VCOPY(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_dcopy(args...));
}
template <typename... ARGS>
static void GEMV(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_dgemv(args...));
}
template <typename... ARGS>
static void GEMM_STRIDED_BATCH(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::rocblas_dgemm_strided_batched(args...));
}
template <typename... ARGS>
static void GEMM_EX(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently there are not cublasDgemmEx."));
}
template <typename... ARGS>
static void TRSM(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_dtrsm(args...));
}
template <typename... ARGS>
static void GETRF_BATCH(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasDgetrfBatched is not supported on HIP platform."));
}
template <typename... ARGS>
static void GETRI_BATCH(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasDgetriBatched is not supported on HIP platform."));
}
template <typename... ARGS>
static void MATINV_BATCH(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasDmatinvBatched is not supported on HIP platform."));
}
};
template <>
struct CUBlas<platform::float16> {
using float16 = platform::float16;
static void GEMM(rocblas_handle handle, rocblas_operation transa,
rocblas_operation transb, int m, int n, int k,
const float16 *alpha, const float16 *A, int lda,
const float16 *B, int ldb, const float16 *beta, float16 *C,
int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_hgemm(
handle, transa, transb, m, n, k,
reinterpret_cast<const rocblas_half *>(alpha),
reinterpret_cast<const rocblas_half *>(A), lda,
reinterpret_cast<const rocblas_half *>(B), ldb,
reinterpret_cast<const rocblas_half *>(beta),
reinterpret_cast<rocblas_half *>(C), ldc));
}
static void GEMM_STRIDED_BATCH(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb, int m, int n, int k,
const float16 *alpha, const float16 *A,
int lda, long long int strideA, // NOLINT
const float16 *B, // NOLINT
int ldb, long long int strideB, // NOLINT
const float16 *beta, float16 *C, int ldc,
long long int strideC, // NOLINT
int batchCount) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::rocblas_hgemm_strided_batched(
handle, transa, transb, m, n, k,
reinterpret_cast<const rocblas_half *>(alpha),
reinterpret_cast<const rocblas_half *>(A), lda, strideA,
reinterpret_cast<const rocblas_half *>(B), ldb, strideB,
reinterpret_cast<const rocblas_half *>(beta),
reinterpret_cast<rocblas_half *>(C), ldc, strideC, batchCount));
}
// NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
template <typename... ARGS>
static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
rocblas_operation transa, rocblas_operation transb, int m,
int n, int k, const void *alpha, const void *A,
rocblas_datatype Atype, int lda, const void *B,
rocblas_datatype Btype, int ldb, const void *beta,
void *C, rocblas_datatype Ctype, int ldc,
rocblas_datatype computeType) {
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_gemm_ex(
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
beta, C, Ctype, ldc, C, Ctype, ldc, computeType, algo, 0, 0));
});
}
};
template <>
struct CUBlas<platform::complex64> {
using complex64 = platform::complex64;
static void GEMV(rocblas_handle handle, rocblas_operation transa, int m,
int n, const complex64 *alpha, const complex64 *A, int lda,
const complex64 *B, int ldb, const complex64 *beta,
complex64 *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_cgemv(
handle, transa, m, n,
reinterpret_cast<const rocblas_float_complex *>(alpha),
reinterpret_cast<const rocblas_float_complex *>(A), lda,
reinterpret_cast<const rocblas_float_complex *>(B), ldb,
reinterpret_cast<const rocblas_float_complex *>(beta),
reinterpret_cast<rocblas_float_complex *>(C), ldc));
}
static void AXPY(rocblas_handle handle, int n, const complex64 *alpha,
const complex64 *X, const int incX, complex64 *Y,
const int incY) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_caxpy(
handle, n, reinterpret_cast<const rocblas_float_complex *>(alpha),
reinterpret_cast<const rocblas_float_complex *>(X), incX,
reinterpret_cast<rocblas_float_complex *>(Y), incY));
}
static void GEMM_STRIDED_BATCH(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb, int m, int n, int k,
const complex64 *alpha, const complex64 *A,
int lda, long long int strideA, // NOLINT
const complex64 *B, // NOLINT
int ldb, long long int strideB, // NOLINT
const complex64 *beta, complex64 *C, int ldc,
long long int strideC, // NOLINT
int batchCount) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::rocblas_cgemm_strided_batched(
handle, transa, transb, m, n, k,
reinterpret_cast<const rocblas_float_complex *>(alpha),
reinterpret_cast<const rocblas_float_complex *>(A), lda, strideA,
reinterpret_cast<const rocblas_float_complex *>(B), ldb, strideB,
reinterpret_cast<const rocblas_float_complex *>(beta),
reinterpret_cast<rocblas_float_complex *>(C), ldc, strideC,
batchCount));
}
static void GEMM(rocblas_handle handle, rocblas_operation transa,
rocblas_operation transb, int m, int n, int k,
const complex64 *alpha, const complex64 *A, int lda,
const complex64 *B, int ldb, const complex64 *beta,
complex64 *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_cgemm(
handle, transa, transb, m, n, k,
reinterpret_cast<const rocblas_float_complex *>(alpha),
reinterpret_cast<const rocblas_float_complex *>(A), lda,
reinterpret_cast<const rocblas_float_complex *>(B), ldb,
reinterpret_cast<const rocblas_float_complex *>(beta),
reinterpret_cast<rocblas_float_complex *>(C), ldc));
}
// NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
template <typename... ARGS>
static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
rocblas_operation transa, rocblas_operation transb, int m,
int n, int k, const void *alpha, const void *A,
rocblas_datatype Atype, int lda, const void *B,
rocblas_datatype Btype, int ldb, const void *beta,
void *C, rocblas_datatype Ctype, int ldc,
rocblas_datatype computeType) {
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_gemm_ex(
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
beta, C, Ctype, ldc, C, Ctype, ldc, computeType, algo, 0, 0));
});
}
};
template <>
struct CUBlas<platform::complex128> {
using complex128 = platform::complex128;
static void GEMV(rocblas_handle handle, rocblas_operation transa, int m,
int n, const complex128 *alpha, const complex128 *A, int lda,
const complex128 *B, int ldb, const complex128 *beta,
complex128 *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zgemv(
handle, transa, m, n,
reinterpret_cast<const rocblas_double_complex *>(alpha),
reinterpret_cast<const rocblas_double_complex *>(A), lda,
reinterpret_cast<const rocblas_double_complex *>(B), ldb,
reinterpret_cast<const rocblas_double_complex *>(beta),
reinterpret_cast<rocblas_double_complex *>(C), ldc));
}
static void AXPY(rocblas_handle handle, int n, const complex128 *alpha,
const complex128 *X, const int incX, complex128 *Y,
const int incY) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zaxpy(
handle, n, reinterpret_cast<const rocblas_double_complex *>(alpha),
reinterpret_cast<const rocblas_double_complex *>(X), incX,
reinterpret_cast<rocblas_double_complex *>(Y), incY));
}
static void GEMM_STRIDED_BATCH(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb, int m, int n, int k,
const complex128 *alpha, const complex128 *A,
int lda, long long int strideA, // NOLINT
const complex128 *B, // NOLINT
int ldb, long long int strideB, // NOLINT
const complex128 *beta, complex128 *C, int ldc,
long long int strideC, // NOLINT
int batchCount) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::rocblas_zgemm_strided_batched(
handle, transa, transb, m, n, k,
reinterpret_cast<const rocblas_double_complex *>(alpha),
reinterpret_cast<const rocblas_double_complex *>(A), lda, strideA,
reinterpret_cast<const rocblas_double_complex *>(B), ldb, strideB,
reinterpret_cast<const rocblas_double_complex *>(beta),
reinterpret_cast<rocblas_double_complex *>(C), ldc, strideC,
batchCount));
}
static void GEMM(rocblas_handle handle, rocblas_operation transa,
rocblas_operation transb, int m, int n, int k,
const complex128 *alpha, const complex128 *A, int lda,
const complex128 *B, int ldb, const complex128 *beta,
complex128 *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zgemm(
handle, transa, transb, m, n, k,
reinterpret_cast<const rocblas_double_complex *>(alpha),
reinterpret_cast<const rocblas_double_complex *>(A), lda,
reinterpret_cast<const rocblas_double_complex *>(B), ldb,
reinterpret_cast<const rocblas_double_complex *>(beta),
reinterpret_cast<rocblas_double_complex *>(C), ldc));
}
// NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
template <typename... ARGS>
static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
rocblas_operation transa, rocblas_operation transb, int m,
int n, int k, const void *alpha, const void *A,
rocblas_datatype Atype, int lda, const void *B,
rocblas_datatype Btype, int ldb, const void *beta,
void *C, rocblas_datatype Ctype, int ldc,
rocblas_datatype computeType) {
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_gemm_ex(
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
beta, C, Ctype, ldc, C, Ctype, ldc, computeType, algo, 0, 0));
});
}
};
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M,
int N, int K, T alpha, const T *A,
const T *B, T beta, T *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
context_.CublasCall([&](rocblas_handle handle) {
CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda,
&beta, C, N);
});
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::float16 alpha, const platform::float16 *A,
const platform::float16 *B, platform::float16 beta,
platform::float16 *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 53,
platform::errors::InvalidArgument(
"cublas fp16 gemm requires GPU compute capability >= 53,"
"but received %d",
context_.GetComputeCapability()));
float h_alpha = static_cast<float>(alpha);
float h_beta = static_cast<float>(beta);
auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
CUBlas<platform::float16>::GEMM_EX(
&cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B,
rocblas_datatype_f16_r, ldb, A, rocblas_datatype_f16_r, lda, &h_beta, C,
rocblas_datatype_f16_r, N, rocblas_datatype_f32_r);
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::complex64 alpha, const platform::complex64 *A,
const platform::complex64 *B, platform::complex64 beta,
platform::complex64 *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 53,
platform::errors::InvalidArgument(
"cublas complex64 gemm requires GPU compute capability >= 53,"
"but received %d",
context_.GetComputeCapability()));
thrust::complex<float> c_alpha =
thrust::complex<float>(alpha.real, alpha.imag);
thrust::complex<float> c_beta = thrust::complex<float>(beta.real, beta.imag);
auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
CUBlas<platform::complex64>::GEMM_EX(
&cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B,
rocblas_datatype_f32_c, ldb, A, rocblas_datatype_f32_c, lda, &c_beta, C,
rocblas_datatype_f32_c, N, rocblas_datatype_f32_c);
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::complex128 alpha, const platform::complex128 *A,
const platform::complex128 *B, platform::complex128 beta,
platform::complex128 *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 53,
platform::errors::InvalidArgument(
"cublas complex128 gemm requires GPU compute capability >= 53,"
"but received %d",
context_.GetComputeCapability()));
thrust::complex<double> c_alpha =
thrust::complex<double>(alpha.real, alpha.imag);
thrust::complex<double> c_beta =
thrust::complex<double>(beta.real, beta.imag);
auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
CUBlas<platform::complex128>::GEMM_EX(
&cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B,
rocblas_datatype_f64_c, ldb, A, rocblas_datatype_f64_c, lda, &c_beta, C,
rocblas_datatype_f64_c, N, rocblas_datatype_f64_c);
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
int N, int K, T alpha, const T *A,
int lda, const T *B, int ldb,
T beta, T *C, int ldc) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
rocblas_operation cuTransA =
transA ? rocblas_operation_transpose : rocblas_operation_none;
rocblas_operation cuTransB =
transB ? rocblas_operation_transpose : rocblas_operation_none;
context_.CublasCall([&](rocblas_handle handle) {
CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda,
&beta, C, ldc);
});
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
bool transA, bool transB, int M, int N, int K, platform::float16 alpha,
const platform::float16 *A, int lda, const platform::float16 *B, int ldb,
platform::float16 beta, platform::float16 *C, int ldc) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
rocblas_operation cuTransA =
transA ? rocblas_operation_transpose : rocblas_operation_none;
rocblas_operation cuTransB =
transB ? rocblas_operation_transpose : rocblas_operation_none;
context_.CublasCall([&](rocblas_handle handle) {
CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
B, ldb, A, lda, &beta, C, ldc);
});
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
T *y) const {
context_.CublasCall([&](rocblas_handle handle) {
CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
});
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::SCAL(int n, const T alpha, T *x) const {
context_.CublasCall(
[&](rocblas_handle handle) { CUBlas<T>::SCAL(handle, n, &alpha, x, 1); });
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::VCOPY(int n, const T *x, T *y) const {
context_.CublasCall(
[&](rocblas_handle handle) { CUBlas<T>::VCOPY(handle, n, x, 1, y, 1); });
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
T alpha, const T *A, const T *B,
T beta, T *C) const {
rocblas_operation cuTransA =
!trans_a ? rocblas_operation_transpose : rocblas_operation_none;
context_.CublasCall([&](rocblas_handle handle) {
CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
});
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMV(
bool trans_a, int M, int N, platform::float16 alpha,
const platform::float16 *A, const platform::float16 *B,
platform::float16 beta, platform::float16 *C) const {
// Because cublas doesn't support half gemv, we use cublasHgemm to achieve it.
if (trans_a) {
this->template GEMM<platform::float16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::float16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
int64_t strideA, int64_t strideB) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_operation cuTransB = (transB == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
const int64_t strideC = M * N;
context_.CublasCall([&](rocblas_handle handle) {
CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
B, ldb, strideB, A, lda, strideA, &beta, C,
ldc, strideC, batchCount);
});
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<T>(transA, transB, M, N, K, alpha, A[k], B[k], beta,
C[k]);
}
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::float16 alpha, const platform::float16 **A,
const platform::float16 **B, platform::float16 beta, platform::float16 **C,
int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::float16>(transA, transB, M, N, K, alpha, A[k],
B[k], beta, C[k]);
}
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
CBLAS_TRANSPOSE transA,
CBLAS_DIAG diag, int M, int N,
T alpha, const T *A, int lda, T *B,
int ldb) const {
// solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'`
// where ' stands for transpose
rocblas_side cuSide =
(side == CblasLeft) ? rocblas_side_right : rocblas_side_left;
rocblas_fill cuUplo =
(uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower;
// use CUBLAS_OP_C (conjugate transpose) for complex
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_diagonal cuDiag =
(diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit;
context_.CublasCall([&](rocblas_handle handle) {
CUBlas<T>::TRSM(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A,
lda, B, ldb);
});
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRF(int n, T **a, int *ipiv,
int *info,
int batch_size) const {
context_.CublasCall([&](rocblas_handle handle) {
CUBlas<T>::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size);
});
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRI(int n, const T **a,
const int *ipiv, T **a_inv,
int *info,
int batch_size) const {
PADDLE_ENFORCE_NE(
a_inv, a,
platform::errors::InvalidArgument(
"cuBLAS fuction 'cublas<S/D>getrfBatched' cannot be executed "
"in-place. The memory space of output matrix (address: %p) cannot "
"overlap memory space of input matrix (address: %p).",
a_inv, a));
context_.CublasCall([&](rocblas_handle handle) {
CUBlas<T>::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size);
});
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedMatInv(int n, const T **a,
T **a_inv, int *info,
int batch_size) const {
context_.CublasCall([&](rocblas_handle handle) {
CUBlas<T>::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size);
});
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -28,8 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group for GRU CPU
template <class OpResetOutput, typename T>
void hl_naive_gru_forward_reset_output(
OpResetOutput op_reset_output, T *gate_value, T *reset_output_value,
......@@ -799,7 +798,7 @@ inline void cpu_gru_backward(const platform::CPUDeviceContext &context,
}
}
#endif
#endif // @} End Group for GRU CPU
} // namespace detail
} // namespace math
......
......@@ -42,7 +42,7 @@ class gru_resetOutput {
(*value_reset_output + *value_reset_bias) * (*value_reset_gate);
}
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group GRU reset output
#ifndef __AVX__
static const bool avx = false;
#else
......@@ -65,7 +65,7 @@ class gru_resetOutput {
}
}
#endif
#endif
#endif // @} End Group GRU reset output
};
template <typename T>
......@@ -84,7 +84,7 @@ class gru_finalOutput {
((*value_update_gate) * (*value_frame_state));
}
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group GRU final output
#ifndef __AVX__
static const bool avx = false;
#else
......@@ -107,7 +107,7 @@ class gru_finalOutput {
}
}
#endif
#endif
#endif // @} End Group GRU final output
};
} // namespace forward
......@@ -137,7 +137,7 @@ class gru_stateGrad {
*value_frame_state, act_input);
}
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group GRU state grad
#ifndef __AVX__
static const bool avx = false;
#else
......@@ -170,7 +170,7 @@ class gru_stateGrad {
}
}
#endif
#endif
#endif // @} End Group GRU state grad
};
template <typename T>
......@@ -187,7 +187,7 @@ class gru_resetGrad {
*grad_reset_gate =
activation(*grad_reset_gate, *value_reset_gate, act_gate);
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group GRU reset grad
#ifndef __AVX__
static const bool avx = false;
#else
......@@ -206,7 +206,7 @@ class gru_resetGrad {
activation(*grad_reset_gate, *value_reset_gate, act_gate);
}
#endif
#endif
#endif // @} End Group GRU reset grad
};
template <typename T>
class gru {
......@@ -230,7 +230,7 @@ class gru {
*value_reset_gate, act_gate);
*grad_reset_output = (*value_reset_gate) * (*grad_frame_state);
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group GRU CPU
#ifndef __AVX__
static const bool avx = false;
#else
......@@ -261,7 +261,7 @@ class gru {
*grad_reset_output = _mm256_mul_ps(*value_reset_gate, *grad_frame_state);
}
#endif
#endif
#endif // @} End Group GRU CPU
};
} // namespace backward
......
......@@ -35,7 +35,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group LSTM CPU
template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
......@@ -467,7 +467,7 @@ void cpu_lstm_backward(const platform::CPUDeviceContext &context, Op op,
}
}
#endif
#endif // @{ End Group LSTM CPU
} // namespace detail
} // namespace math
......
......@@ -50,7 +50,7 @@ class lstm {
*state_atv = activation(*state, active_state);
*output = (*value_og) * (*state_atv);
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group LSTM FWD
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
static const bool avx = false;
#else
......@@ -87,7 +87,7 @@ class lstm {
*output = _mm256_mul_ps(*value_og, *state_atv);
}
#endif
#endif
#endif // @} End Group LSTM FWD
};
} // namespace forward
......@@ -132,7 +132,7 @@ class lstm {
*checkFGrad = (*grad_fg) * (*prev_state);
*checkOGrad = (*grad_og) * (*state);
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group LSTM BWD
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
static const bool avx = false;
#else
......@@ -177,7 +177,7 @@ class lstm {
*checkOGrad = _mm256_mul_ps(*grad_og, *state);
}
#endif
#endif
#endif // @} End Group LSTM BWD
};
} // namespace backward
......
......@@ -39,7 +39,7 @@ BufferedReader::BufferedReader(
buffer_size_(buffer_size),
pin_memory_(pin_memory) {
VLOG(1) << "BufferedReader";
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(place_) && !pin_memory) {
int dev_idx = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
compute_stream_ =
......@@ -74,7 +74,7 @@ void BufferedReader::ReadAsync(size_t i) {
return -1UL;
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // @{ Group GPU Place
if (platform::is_gpu_place(place_)) {
TensorVec &cuda = cuda_buffer_[i];
if (cuda.empty()) {
......@@ -142,10 +142,17 @@ void BufferedReader::ReadAsync(size_t i) {
// cuda memory immediately without waiting cuda kernel ends
platform::SetDeviceId(
BOOST_GET_CONST(platform::CUDAPlace, place_).device);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(
hipEventRecord(events_[i].get(), compute_stream_));
PADDLE_ENFORCE_CUDA_SUCCESS(
hipStreamWaitEvent(stream_.get(), events_[i].get(), 0));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventRecord(events_[i].get(), compute_stream_));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamWaitEvent(stream_.get(), events_[i].get(), 0));
#endif
platform::RecordEvent record_event("BufferedReader:MemoryCopy");
for (size_t i = 0; i < cpu.size(); ++i) {
......@@ -174,14 +181,22 @@ void BufferedReader::ReadAsync(size_t i) {
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place_), gpu_ptr,
cuda_pinned_place, cuda_pinned_ptr, size,
stream_.get());
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream_.get()));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_.get()));
#endif
}
cuda[i].set_lod(cpu[i].lod());
}
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream_.get()));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_.get()));
#endif
}
}
#endif
#endif // @} End Group GPU Place
return i;
}));
}
......
......@@ -21,7 +21,7 @@
#include "ThreadPool.h"
#include "paddle/fluid/framework/reader.h"
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_resource_pool.h"
#include "paddle/fluid/platform/gpu_info.h"
#endif
......@@ -68,8 +68,8 @@ class BufferedReader : public framework::DecoratedReader {
std::vector<TensorVec> cpu_buffer_;
std::vector<TensorVec> cuda_buffer_;
size_t prev_pos_{-1UL};
#ifdef PADDLE_WITH_CUDA
cudaStream_t compute_stream_;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gpuStream_t compute_stream_;
std::shared_ptr<platform::CudaStreamObject> stream_;
std::vector<std::shared_ptr<platform::CudaEventObject>> events_;
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册