From 981bc7d4934c70a6f00a33c5a7946352c3ee76cb Mon Sep 17 00:00:00 2001
From: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Date: Tue, 12 Jan 2021 21:35:56 -0800
Subject: [PATCH] Move workspace memory-allocation to PyTorch (#661)

* move workspace memory-allocation to PyTorch

* refine the code based on the comments

* remove unnecessary options

* remove bsz from set_seq_len function
---
 csrc/includes/context.h                  | 14 ++-------
 csrc/includes/ds_transformer_cuda.h      |  5 ++-
 csrc/transformer/ds_transformer_cuda.cpp | 40 ++++++++++++++----------
 3 files changed, 31 insertions(+), 28 deletions(-)
 mode change 100644 => 100755 csrc/includes/ds_transformer_cuda.h

diff --git a/csrc/includes/context.h b/csrc/includes/context.h
index f8ae6fc4..c2e26cdf 100755
--- a/csrc/includes/context.h
+++ b/csrc/includes/context.h
@@ -64,17 +64,10 @@ public:
         return _ctx;
     }
 
-    void GenWorkSpace(size_t size)
+    void SetWorkSpace(void* workspace)
     {
-        if (!_workspace) {
-            assert(_workspace == nullptr);
-            cudaMalloc(&_workspace, size);
-        } else if (_workSpaceSize < size) {
-            cudaFree(_workspace);
-            cudaMalloc(&_workspace, size);
-        }
-
-        _workSpaceSize = size;
+        if (!workspace) { throw std::runtime_error("Workspace is null."); }
+        _workspace = workspace;
     }
 
     void* GetWorkSpace() { return _workspace; }
@@ -172,6 +165,5 @@ private:
     void* _workspace;
     uint64_t _seed;
     uint64_t _curr_offset;
-    size_t _workSpaceSize;
     std::vector<std::array<int, 3>> _gemm_algos;
 };
diff --git a/csrc/includes/ds_transformer_cuda.h b/csrc/includes/ds_transformer_cuda.h
old mode 100644
new mode 100755
index 3fac43e4..dbae797a
--- a/csrc/includes/ds_transformer_cuda.h
+++ b/csrc/includes/ds_transformer_cuda.h
@@ -130,10 +130,13 @@ public:
     inline int GetBatchSize() const { return _batch_size; }
     inline int GetNumHeads() const { return _heads; }
     inline int GetSeqLength() const { return _seq_length; }
+    inline int GetIntermediateSize() const { return _intermediate_size; }
 
-    void SetSeqLength(int seq_len, int bsz);
+    void SetSeqLength(int seq_len);
     inline int GetHiddenSize() const { return _hidden_size; }
     void SetTrainingMode(bool training);
+    inline bool IsTrainingMode() const { return _training; }
+    inline bool GeluCheckpoint() const { return _gelu_checkpoint; }
 
 private:
     void Initialize();
diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp
index ebd534d0..f22b8a07 100755
--- a/csrc/transformer/ds_transformer_cuda.cpp
+++ b/csrc/transformer/ds_transformer_cuda.cpp
@@ -33,7 +33,7 @@ size_t get_workspace_size(int maxBatchSize,
                                    2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
         if (gelu_checkpoint) workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
     }
-    return workSpacesize * sizeof(T);
+    return workSpacesize; // * sizeof(T);
 }
 
 // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
@@ -123,7 +123,6 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
                                                   gemm_algos[4]))
 {
     assert(_hidden_size % _heads == 0);
-    assert(_seq_length <= 1024);
 
     Initialize();
 }
@@ -136,14 +135,6 @@ BertTransformerLayer<T>::~BertTransformerLayer()
 template <typename T>
 void BertTransformerLayer<T>::Initialize()
 {
-    Context::Instance().GenWorkSpace(get_workspace_size<T>(_batch_size,
-                                                           _seq_length,
-                                                           _hidden_size,
-                                                           _intermediate_size,
-                                                           _heads,
-                                                           _training,
-                                                           _gelu_checkpoint));
-
     if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
 }
 
@@ -574,7 +565,7 @@ void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_
 }
 
 template <typename T>
-void BertTransformerLayer<T>::SetSeqLength(int seq_len, int bsz)
+void BertTransformerLayer<T>::SetSeqLength(int seq_len)
 {
     _seq_length = seq_len;
 
@@ -582,9 +573,6 @@ void BertTransformerLayer<T>::SetSeqLength(int seq_len, int bsz)
     _attn_prob_dropout.SetDimension(_seq_length);
     _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
     _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
-
-    Context::Instance().GenWorkSpace(get_workspace_size<T>(
-        bsz, _seq_length, _hidden_size, _intermediate_size, _heads, _training, _gelu_checkpoint));
 }
 
 template <typename T>
@@ -707,9 +695,19 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
     int seq_len = layer->GetSeqLength();
     if (input.size(1) != seq_len) {
         seq_len = input.size(1);
-        layer->SetSeqLength(seq_len, bsz);
+        layer->SetSeqLength(seq_len);
     }
 
+    auto workspace = torch::empty({get_workspace_size<T>(bsz,
+                                                          seq_len,
+                                                          layer->GetHiddenSize(),
+                                                          layer->GetIntermediateSize(),
+                                                          layer->GetNumHeads(),
+                                                          layer->IsTrainingMode(),
+                                                          layer->GeluCheckpoint())},
+                                  options);
+    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+
     auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
     auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
     auto attn_o_inp = torch::empty_like(input);
@@ -877,9 +875,19 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
     int seq_len = layer->GetSeqLength();
     if (g_output.size(1) != seq_len) {
         seq_len = g_output.size(1);
-        layer->SetSeqLength(seq_len, bsz);
+        layer->SetSeqLength(seq_len);
     }
 
+    auto workspace = torch::empty({get_workspace_size<T>(bsz,
+                                                          seq_len,
+                                                          layer->GetHiddenSize(),
+                                                          layer->GetIntermediateSize(),
+                                                          layer->GetNumHeads(),
+                                                          layer->IsTrainingMode(),
+                                                          layer->GeluCheckpoint())},
+                                  grad_output.options());
+    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+
     auto grad_input = torch::empty_like(input);
     auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
     auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
-- 
GitLab
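
Note (not part of the patch): a minimal sketch of the allocation flow introduced above. The caller sizes the workspace with get_workspace_size<T>(), allocates it through PyTorch as a torch::Tensor, and registers the raw pointer via Context::Instance().SetWorkSpace(). The helper name prepare_workspace and the dtype selection are illustrative assumptions, not code from the patch.

// Assumes DeepSpeed's context.h, the get_workspace_size<T> template shown above,
// <torch/extension.h>, <type_traits>, and <cuda_fp16.h> are in scope.
template <typename T>
torch::Tensor prepare_workspace(int bsz,
                                int seq_len,
                                int hidden_size,
                                int intermediate_size,
                                int heads,
                                bool training,
                                bool gelu_checkpoint)
{
    // get_workspace_size<T> now returns an element count (the "* sizeof(T)" was dropped),
    // so it can be used directly as the shape of a 1-D tensor of element type T.
    int64_t elements = (int64_t)get_workspace_size<T>(
        bsz, seq_len, hidden_size, intermediate_size, heads, training, gelu_checkpoint);

    auto options = torch::TensorOptions()
                       .dtype(std::is_same<T, __half>::value ? torch::kHalf : torch::kFloat32)
                       .device(torch::kCUDA)
                       .requires_grad(false);

    auto workspace = torch::empty({elements}, options);

    // Context only stores the raw pointer; the caller must keep the returned tensor
    // alive until every kernel that uses the workspace has finished.
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
    return workspace;
}

Because the buffer now comes from PyTorch's caching allocator rather than a raw cudaMalloc inside GenWorkSpace, it is reused across calls without an explicit cudaFree and shows up in torch.cuda memory accounting.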