未验证 提交 84639b61 编写于 作者: Q Qi Li 提交者: GitHub

[ROCM] update fluid operators for rocm (part3), test=develop (#31213)

* [ROCM] update fluid operators for rocm (part3), test=develop

* fix clang format error, test=develop
上级 3b9db171
......@@ -24,22 +24,28 @@ file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_gru);\n")
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_lstm);\n")
if (WITH_GPU)
if (WITH_GPU OR WITH_ROCM)
# fused_bn_activation_op needs cudnn 7.4.1 above
if (NOT ${CUDNN_VERSION} VERSION_LESS 7401)
# HIP not support bn act fuse in MIOPEN
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7401))
op_library(fused_bn_activation_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_batch_norm_act);\n")
endif()
# conv_fusion_op needs cudnn 7 above
if (NOT ${CUDNN_VERSION} VERSION_LESS 7100)
# HIP not support cudnnConvolutionBiasActivationForward
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7100))
op_library(conv_fusion_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
endif()
# fusion_transpose_flatten_concat_op
op_library(fusion_transpose_flatten_concat_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_transpose_flatten_concat);\n")
# HIP not support cudnnTransformTensor
if(NOT WITH_ROCM)
op_library(fusion_transpose_flatten_concat_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_transpose_flatten_concat);\n")
endif()
# fusion_conv_inception_op needs cudnn 7 above
if (NOT ${CUDNN_VERSION} VERSION_LESS 7100)
# HIP not support cudnnConvolutionBiasActivationForward
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7100))
op_library(fusion_conv_inception_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_inception_fusion);\n")
endif()
......@@ -60,8 +66,9 @@ if (WITH_GPU)
cc_test(test_fusion_group_op SRCS fusion_group_op_test.cc DEPS fusion_group_op)
endif()
# fused_bn_add_activation
if (NOT ${CUDNN_VERSION} VERSION_LESS 7401)
op_library(fused_bn_add_activation_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_bn_add_activation);\n")
# HIP not support bn act fuse in MIOPEN
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7401))
op_library(fused_bn_add_activation_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_bn_add_activation);\n")
endif()
endif()
......@@ -12,10 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include <cub/cub.cuh> // NOLINT
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
......@@ -39,7 +37,11 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
in_embs_(framework::proto::VarType::INT64);
framework::DDim in_dim{input_num};
int device_id;
#ifdef PADDLE_WITH_HIP
hipGetDevice(&device_id);
#else
cudaGetDevice(&device_id);
#endif
in_ids_.Resize(in_dim);
in_embs_.Resize(in_dim);
int64_t *in_ids_d =
......@@ -52,11 +54,17 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
in1s.push_back(reinterpret_cast<uintptr_t>(ids[i]->data<int64_t>()));
in2s.push_back(reinterpret_cast<uintptr_t>(embs[i]->data<T>()));
}
#ifdef PADDLE_WITH_HIP
hipMemcpyAsync(in_ids_d, in1s.data(), sizeof(int64_t) * input_num,
hipMemcpyHostToDevice, device_ctx.stream());
hipMemcpyAsync(in_embs_d, in2s.data(), sizeof(int64_t) * input_num,
hipMemcpyHostToDevice, device_ctx.stream());
#else
cudaMemcpyAsync(in_ids_d, in1s.data(), sizeof(int64_t) * input_num,
cudaMemcpyHostToDevice, device_ctx.stream());
cudaMemcpyAsync(in_embs_d, in2s.data(), sizeof(int64_t) * input_num,
cudaMemcpyHostToDevice, device_ctx.stream());
#endif
auto *bias = context.Input<framework::Tensor>("Bias");
auto *scale = context.Input<framework::Tensor>("Scale");
......
......@@ -12,7 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/cuda_device_function.h"
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
......@@ -89,7 +88,7 @@ __global__ void TransposeQkvKernel(const int H, const T *input, const T *bias,
void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
const int head_num, const float *input, const float *bias,
float *output, cudaStream_t stream) {
float *output, gpuStream_t stream) {
// BxSx3xNxH + 3xNxH -> 3xBxNxSxH
int scratch_size = batch * head_num * seq_len * seq_len;
const dim3 grid(seq_len, batch, 3);
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
......
......@@ -83,7 +83,7 @@ class LiteEngineOp : public framework::OperatorBase {
<< engine_->GetInputNames()[i] << ")";
inference::lite::utils::TensorCopy(&dst_t, &src_t, *ctx, zero_copy_);
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(dev_place)) {
platform::GpuStreamSync(
static_cast<const platform::CUDADeviceContext *>(ctx)->stream());
......@@ -101,7 +101,7 @@ class LiteEngineOp : public framework::OperatorBase {
<< engine_->GetOutputNames()[i] << ")";
inference::lite::utils::TensorCopy(dst_t, &src_t, *ctx, zero_copy_);
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(dev_place)) {
platform::GpuStreamSync(
static_cast<const platform::CUDADeviceContext *>(ctx)->stream());
......
......@@ -67,7 +67,7 @@ TEST(LiteEngineOp, engine_op) {
*block_->add_ops() = *elt_add->Proto();
*block_->add_ops() = *fetch->Proto();
framework::Scope scope;
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
#else
......@@ -84,11 +84,11 @@ TEST(LiteEngineOp, engine_op) {
std::vector<std::string> repetitive_params{"x", "y"};
inference::lite::EngineConfig config;
config.valid_places = {
#ifdef PADDLE_WITH_CUDA
paddle::lite_api::Place({TARGET(kCUDA), PRECISION(kFloat)}),
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
paddle::lite_api::Place({TARGET(kCUDA), PRECISION(kFloat)}),
#endif
paddle::lite_api::Place({TARGET(kX86), PRECISION(kFloat)}),
paddle::lite_api::Place({TARGET(kHost), PRECISION(kAny)}),
paddle::lite_api::Place({TARGET(kX86), PRECISION(kFloat)}),
paddle::lite_api::Place({TARGET(kHost), PRECISION(kAny)}),
};
serialize_params(&(config.param), &scope, repetitive_params);
config.model = program.Proto()->SerializeAsString();
......
......@@ -55,7 +55,7 @@ void AddFetchListToBlockDesc(framework::proto::BlockDesc* block,
void serialize_params(std::string* str, framework::Scope* scope,
const std::vector<std::string>& params) {
std::ostringstream os;
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
#else
......@@ -106,7 +106,7 @@ void CreateTensor(framework::Scope* scope, const std::string& name,
tensor->Resize(dims);
platform::Place place;
if (in_cuda) {
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
place = platform::CUDAPlace(0);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
......
......@@ -41,7 +41,7 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
template <typename T>
HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) {
#ifdef __CUDA_ARCH__
#if defined(__CUDA_ARCH__) || defined(__HIPCC__) // @{ Group LowerBound
// The following code is from
// https://en.cppreference.com/w/cpp/algorithm/lower_bound
auto *first = x;
......@@ -59,12 +59,12 @@ HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) {
return static_cast<size_t>(first - x);
#else
return static_cast<size_t>(std::lower_bound(x, x + num, val) - x);
#endif
#endif // @} End Group LowerBound
}
template <typename T>
HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) {
#ifdef __CUDA_ARCH__
#if defined(__CUDA_ARCH__) || defined(__HIPCC__) // @{ Group UpperBound
// The following code is from
// https://en.cppreference.com/w/cpp/algorithm/upper_bound
auto *first = x;
......@@ -82,7 +82,7 @@ HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) {
return static_cast<size_t>(first - x);
#else
return static_cast<size_t>(std::upper_bound(x, x + num, val) - x);
#endif
#endif // @} End Group UpperBound
}
} // namespace math
......
......@@ -134,7 +134,7 @@ TEST(BeamSearch, CPU) {
paddle::platform::CPUPlace>();
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(BeamSearch, GPU) {
TestBeamSearch<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace>();
......
......@@ -102,7 +102,7 @@ class Blas {
T alpha, const T* A, int lda, const T* B, int ldb, T beta, T* C,
int ldc) const;
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML // @{ Group MKLML: class Blas
template <typename T>
T* GEMM_ALLOC(const CBLAS_IDENTIFIER id, const int M, const int N,
const int K) const;
......@@ -126,7 +126,7 @@ class Blas {
const int* indx, const int* pntrb, const int* pntre, const T* b,
const int* ldb, const T* beta, T* c, const int* ldc) const;
#if !defined(PADDLE_WITH_CUDA)
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
template <typename T>
void MatMulWithHead(const framework::Tensor& mat_a,
const MatDescriptor& dim_a,
......@@ -135,7 +135,7 @@ class Blas {
framework::Tensor* mat_out, T beta,
bool mat_y_split_vertical) const;
#endif
#endif
#endif // @} End Group MKLML: class Blas
template <typename T>
void MatMul(const int M, const int N, const int K, const T* A, const T* B,
......@@ -210,7 +210,8 @@ class Blas {
int K, T alpha, const T** A, const T** B, T beta, T** C,
int batchCount) const;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP)
template <typename T>
void BatchedGEMMWithHead(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
int W1, int H1, int W2, int H2, T alpha, const T* A,
......@@ -235,7 +236,7 @@ class Blas {
CBLAS_DIAG diag, int M, int N, T alpha, const T* A, int lda, T* B,
int ldb) const;
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T>
void BatchedGETRF(int n, T** a, int* ipiv, int* info, int batch_size) const;
......@@ -262,7 +263,7 @@ class BlasT : private Blas<DeviceContext> {
Base()->template GEMM<T>(args...);
}
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML // @{ Group MKLML: class BlasT
template <typename... ARGS>
T* GEMM_ALLOC(ARGS... args) const {
return Base()->template GEMM_ALLOC<T>(args...);
......@@ -288,13 +289,13 @@ class BlasT : private Blas<DeviceContext> {
Base()->template CSRMM<T>(args...);
}
#if !defined(PADDLE_WITH_CUDA)
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
template <typename... ARGS>
void MatMulWithHead(ARGS... args) const {
Base()->template MatMulWithHead<T>(args...);
}
#endif
#endif
#endif // @} End Group MKLML: class BlasT
template <typename... ARGS>
void MatMul(ARGS... args) const {
......@@ -386,7 +387,7 @@ class BlasT : private Blas<DeviceContext> {
Base()->template TRSM<T>(args...);
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename... ARGS>
void BatchedGETRF(ARGS... args) const {
Base()->template BatchedGETRF<T>(args...);
......@@ -429,3 +430,6 @@ inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/blas_impl.cu.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/operators/math/blas_impl.hip.h"
#endif
......@@ -1046,7 +1046,8 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
#endif
}
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP) // @{ Group Blas MKLML: BatchedGEMMWithHead
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
......@@ -1116,7 +1117,7 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMMWithHead(
}
}
}
#endif
#endif // @} End Group Blas MKLML: BatchedGEMMWithHead
template <typename DeviceContext>
template <typename T>
......@@ -1192,7 +1193,9 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
}
}
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \
!defined(PADDLE_WITH_HIP)
// @{ Group Blas MKLML: MatMulWithHead
/*
* Multiple two matrixes with multiple heads
*
......@@ -1319,7 +1322,7 @@ void Blas<DeviceContext>::MatMulWithHead(const framework::Tensor &mat_a,
dim_a.stride_, dim_b.stride_, head_number, mat_b_split_vertical);
}
}
#endif
#endif // @} End Group Blas MKLML: MatMulWithHead
template <typename DeviceContext>
template <typename T>
......
此差异已折叠。
......@@ -28,8 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group for GRU CPU
template <class OpResetOutput, typename T>
void hl_naive_gru_forward_reset_output(
OpResetOutput op_reset_output, T *gate_value, T *reset_output_value,
......@@ -799,7 +798,7 @@ inline void cpu_gru_backward(const platform::CPUDeviceContext &context,
}
}
#endif
#endif // @} End Group for GRU CPU
} // namespace detail
} // namespace math
......
......@@ -42,7 +42,7 @@ class gru_resetOutput {
(*value_reset_output + *value_reset_bias) * (*value_reset_gate);
}
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group GRU reset output
#ifndef __AVX__
static const bool avx = false;
#else
......@@ -65,7 +65,7 @@ class gru_resetOutput {
}
}
#endif
#endif
#endif // @} End Group GRU reset output
};
template <typename T>
......@@ -84,7 +84,7 @@ class gru_finalOutput {
((*value_update_gate) * (*value_frame_state));
}
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group GRU final output
#ifndef __AVX__
static const bool avx = false;
#else
......@@ -107,7 +107,7 @@ class gru_finalOutput {
}
}
#endif
#endif
#endif // @} End Group GRU final output
};
} // namespace forward
......@@ -137,7 +137,7 @@ class gru_stateGrad {
*value_frame_state, act_input);
}
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group GRU state grad
#ifndef __AVX__
static const bool avx = false;
#else
......@@ -170,7 +170,7 @@ class gru_stateGrad {
}
}
#endif
#endif
#endif // @} End Group GRU state grad
};
template <typename T>
......@@ -187,7 +187,7 @@ class gru_resetGrad {
*grad_reset_gate =
activation(*grad_reset_gate, *value_reset_gate, act_gate);
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group GRU reset grad
#ifndef __AVX__
static const bool avx = false;
#else
......@@ -206,7 +206,7 @@ class gru_resetGrad {
activation(*grad_reset_gate, *value_reset_gate, act_gate);
}
#endif
#endif
#endif // @} End Group GRU reset grad
};
template <typename T>
class gru {
......@@ -230,7 +230,7 @@ class gru {
*value_reset_gate, act_gate);
*grad_reset_output = (*value_reset_gate) * (*grad_frame_state);
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group GRU CPU
#ifndef __AVX__
static const bool avx = false;
#else
......@@ -261,7 +261,7 @@ class gru {
*grad_reset_output = _mm256_mul_ps(*value_reset_gate, *grad_frame_state);
}
#endif
#endif
#endif // @} End Group GRU CPU
};
} // namespace backward
......
......@@ -35,7 +35,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group LSTM CPU
template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
......@@ -467,7 +467,7 @@ void cpu_lstm_backward(const platform::CPUDeviceContext &context, Op op,
}
}
#endif
#endif // @{ End Group LSTM CPU
} // namespace detail
} // namespace math
......
......@@ -50,7 +50,7 @@ class lstm {
*state_atv = activation(*state, active_state);
*output = (*value_og) * (*state_atv);
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group LSTM FWD
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
static const bool avx = false;
#else
......@@ -87,7 +87,7 @@ class lstm {
*output = _mm256_mul_ps(*value_og, *state_atv);
}
#endif
#endif
#endif // @} End Group LSTM FWD
};
} // namespace forward
......@@ -132,7 +132,7 @@ class lstm {
*checkFGrad = (*grad_fg) * (*prev_state);
*checkOGrad = (*grad_og) * (*state);
}
#ifndef __NVCC__
#if !defined(__NVCC__) && !defined(__HIPCC___) // @{ Group LSTM BWD
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
static const bool avx = false;
#else
......@@ -177,7 +177,7 @@ class lstm {
*checkOGrad = _mm256_mul_ps(*grad_og, *state);
}
#endif
#endif
#endif // @} End Group LSTM BWD
};
} // namespace backward
......
......@@ -39,7 +39,7 @@ BufferedReader::BufferedReader(
buffer_size_(buffer_size),
pin_memory_(pin_memory) {
VLOG(1) << "BufferedReader";
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(place_) && !pin_memory) {
int dev_idx = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
compute_stream_ =
......@@ -74,7 +74,7 @@ void BufferedReader::ReadAsync(size_t i) {
return -1UL;
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // @{ Group GPU Place
if (platform::is_gpu_place(place_)) {
TensorVec &cuda = cuda_buffer_[i];
if (cuda.empty()) {
......@@ -142,10 +142,17 @@ void BufferedReader::ReadAsync(size_t i) {
// cuda memory immediately without waiting cuda kernel ends
platform::SetDeviceId(
BOOST_GET_CONST(platform::CUDAPlace, place_).device);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(
hipEventRecord(events_[i].get(), compute_stream_));
PADDLE_ENFORCE_CUDA_SUCCESS(
hipStreamWaitEvent(stream_.get(), events_[i].get(), 0));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventRecord(events_[i].get(), compute_stream_));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamWaitEvent(stream_.get(), events_[i].get(), 0));
#endif
platform::RecordEvent record_event("BufferedReader:MemoryCopy");
for (size_t i = 0; i < cpu.size(); ++i) {
......@@ -174,14 +181,22 @@ void BufferedReader::ReadAsync(size_t i) {
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place_), gpu_ptr,
cuda_pinned_place, cuda_pinned_ptr, size,
stream_.get());
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream_.get()));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_.get()));
#endif
}
cuda[i].set_lod(cpu[i].lod());
}
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream_.get()));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_.get()));
#endif
}
}
#endif
#endif // @} End Group GPU Place
return i;
}));
}
......
......@@ -21,7 +21,7 @@
#include "ThreadPool.h"
#include "paddle/fluid/framework/reader.h"
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_resource_pool.h"
#include "paddle/fluid/platform/gpu_info.h"
#endif
......@@ -68,8 +68,8 @@ class BufferedReader : public framework::DecoratedReader {
std::vector<TensorVec> cpu_buffer_;
std::vector<TensorVec> cuda_buffer_;
size_t prev_pos_{-1UL};
#ifdef PADDLE_WITH_CUDA
cudaStream_t compute_stream_;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gpuStream_t compute_stream_;
std::shared_ptr<platform::CudaStreamObject> stream_;
std::vector<std::shared_ptr<platform::CudaEventObject>> events_;
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册