Unverified commit 981bc7d4, authored by Reza Yazdani, committed by GitHub

Move workspace memory-allocation to PyTorch (#661)

* move workspace memory-allocation to PyTorch

* refine the code based on the comments

* remove unnecessary options

* remove bsz from set_seq_len function
Parent: e2fbe4d2
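In short, the commit inverts ownership of the scratch buffer: instead of the CUDA context growing a cudaMalloc'd workspace on demand, each forward/backward op allocates a torch::Tensor of the required size and hands the raw pointer to the context. A minimal sketch of that pattern, using simplified stand-in names (WorkspaceHolder, run_op) for the real Context and ops shown in the hunks below:

    #include <torch/extension.h>
    #include <stdexcept>

    // Non-owning holder for the workspace pointer; the backing memory is
    // owned by the torch::Tensor, i.e. by PyTorch's CUDA caching allocator.
    class WorkspaceHolder {
    public:
        void Set(void* ws)
        {
            if (!ws) { throw std::runtime_error("Workspace is null."); }
            _ws = ws;
        }
        void* Get() const { return _ws; }

    private:
        void* _ws = nullptr;
    };

    void run_op(const torch::Tensor& input, int64_t workspace_elements, WorkspaceHolder& holder)
    {
        // torch::empty routes through the caching allocator, so repeated
        // calls are typically served from the cache rather than cudaMalloc.
        auto workspace = torch::empty({workspace_elements}, input.options());
        holder.Set(workspace.data_ptr());
        // ... launch kernels that consume holder.Get() ...
    }  // `workspace` goes out of scope here and returns to the cache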
@@ -64,17 +64,10 @@ public:
         return _ctx;
     }
 
-    void GenWorkSpace(size_t size)
+    void SetWorkSpace(void* workspace)
     {
-        if (!_workspace) {
-            assert(_workspace == nullptr);
-            cudaMalloc(&_workspace, size);
-        } else if (_workSpaceSize < size) {
-            cudaFree(_workspace);
-            cudaMalloc(&_workspace, size);
-        }
-
-        _workSpaceSize = size;
+        if (!workspace) { throw std::runtime_error("Workspace is null."); }
+        _workspace = workspace;
     }
 
     void* GetWorkSpace() { return _workspace; }
@@ -172,6 +165,5 @@ private:
     void* _workspace;
     uint64_t _seed;
     uint64_t _curr_offset;
-    size_t _workSpaceSize;
     std::vector<std::array<int, 3>> _gemm_algos;
 };
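Note the ownership change in these two hunks: with `_workSpaceSize` and the `cudaFree`/`cudaMalloc` pair gone, `SetWorkSpace` stores a borrowed pointer and `Context` frees nothing. The tensor backing that pointer must therefore outlive every kernel launch that touches the workspace. A hypothetical misuse (`n` and `options` are placeholders):

    void dangling_workspace(int64_t n, torch::TensorOptions options)
    {
        {
            auto ws = torch::empty({n}, options);
            Context::Instance().SetWorkSpace(ws.data_ptr());
        }  // `ws` is released back to the caching allocator here
        // GetWorkSpace() now returns a dangling pointer. The ops below avoid
        // this by keeping `workspace` alive for the entire op body.
    }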
@@ -130,10 +130,13 @@ public:
     inline int GetBatchSize() const { return _batch_size; }
     inline int GetNumHeads() const { return _heads; }
     inline int GetSeqLength() const { return _seq_length; }
+    inline int GetIntermediateSize() const { return _intermediate_size; }
 
-    void SetSeqLength(int seq_len, int bsz);
+    void SetSeqLength(int seq_len);
     inline int GetHiddenSize() const { return _hidden_size; }
     void SetTrainingMode(bool training);
+    inline bool IsTrainingMode() const { return _training; }
+    inline bool GeluCheckpoint() const { return _gelu_checkpoint; }
 
 private:
     void Initialize();
@@ -33,7 +33,7 @@ size_t get_workspace_size(int maxBatchSize,
                 2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
         if (gelu_checkpoint) workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
     }
-    return workSpacesize * sizeof(T);
+    return workSpacesize; // * sizeof(T);
 }
 
 // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
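The dropped `* sizeof(T)` factor is the unit change behind the move: `cudaMalloc` took a byte count, while `torch::empty` takes an element count and derives the byte size from the dtype carried in the tensor options. A small sketch of the difference (half precision assumed for illustration):

    void alloc_units(size_t elements)
    {
        // Old path: explicit byte count.
        //   cudaMalloc(&ptr, elements * sizeof(__half));
        // New path: element count; kHalf contributes the 2-byte element size.
        auto ws = torch::empty({(int64_t)elements},
                               torch::dtype(torch::kHalf).device(torch::kCUDA));
    }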
@@ -123,7 +123,6 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
               gemm_algos[4]))
 {
     assert(_hidden_size % _heads == 0);
-    assert(_seq_length <= 1024);
 
     Initialize();
 }
@@ -136,14 +135,6 @@ BertTransformerLayer<T>::~BertTransformerLayer()
 
 template <typename T>
 void BertTransformerLayer<T>::Initialize()
 {
-    Context::Instance().GenWorkSpace(get_workspace_size<T>(_batch_size,
-                                                           _seq_length,
-                                                           _hidden_size,
-                                                           _intermediate_size,
-                                                           _heads,
-                                                           _training,
-                                                           _gelu_checkpoint));
-
     if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
 }
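With the `GenWorkSpace` call removed from `Initialize()`, constructing a layer reserves no workspace at all; the buffer exists only while an op runs. A hypothetical check of the resulting invariant (assuming `_workspace` starts as `nullptr`):

    #include <cassert>

    void check_no_eager_allocation()
    {
        // Nothing is allocated at layer-construction time any more; the
        // context stays empty until an op installs a workspace.
        assert(Context::Instance().GetWorkSpace() == nullptr);
        // ds_transformer_forward/backward call SetWorkSpace before any
        // kernel launch, so kernels never observe the null pointer.
    }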
@@ -574,7 +565,7 @@ void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_
 }
 
 template <typename T>
-void BertTransformerLayer<T>::SetSeqLength(int seq_len, int bsz)
+void BertTransformerLayer<T>::SetSeqLength(int seq_len)
 {
     _seq_length = seq_len;
@@ -582,9 +573,6 @@ void BertTransformerLayer<T>::SetSeqLength(int seq_len, int bsz)
     _attn_prob_dropout.SetDimension(_seq_length);
     _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
     _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
-
-    Context::Instance().GenWorkSpace(get_workspace_size<T>(
-        bsz, _seq_length, _hidden_size, _intermediate_size, _heads, _training, _gelu_checkpoint));
 }
 
 template <typename T>
@@ -707,9 +695,19 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
     int seq_len = layer->GetSeqLength();
 
     if (input.size(1) != seq_len) {
         seq_len = input.size(1);
-        layer->SetSeqLength(seq_len, bsz);
+        layer->SetSeqLength(seq_len);
     }
 
+    auto workspace = torch::empty({get_workspace_size<T>(bsz,
+                                                         seq_len,
+                                                         layer->GetHiddenSize(),
+                                                         layer->GetIntermediateSize(),
+                                                         layer->GetNumHeads(),
+                                                         layer->IsTrainingMode(),
+                                                         layer->GeluCheckpoint())},
+                                  options);
+    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+
     auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
     auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
     auto attn_o_inp = torch::empty_like(input);
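Allocating the workspace on every call looks expensive, but because `torch::empty` goes through PyTorch's caching allocator the block is reused across iterations; it also means the workspace is always sized for the current `bsz` and `seq_len`, which is why `SetSeqLength` no longer needs `bsz` and the `_seq_length <= 1024` assert could be dropped. A sketch of the steady-state behavior (`num_steps`, `ws_elems`, and `opts` are placeholders):

    void train_loop(int num_steps, int64_t ws_elems, torch::TensorOptions opts)
    {
        for (int step = 0; step < num_steps; ++step) {
            // After step 0 this is a cache hit: the allocator hands back the
            // block released at the end of the previous iteration.
            auto ws = torch::empty({ws_elems}, opts);
            Context::Instance().SetWorkSpace(ws.data_ptr());
            // ... forward/backward kernels consume the workspace here ...
        }  // each iteration's tensor returns to the cache, not the driver
    }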
@@ -877,9 +875,19 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
     int seq_len = layer->GetSeqLength();
 
     if (g_output.size(1) != seq_len) {
         seq_len = g_output.size(1);
-        layer->SetSeqLength(seq_len, bsz);
+        layer->SetSeqLength(seq_len);
     }
 
+    auto workspace = torch::empty({get_workspace_size<T>(bsz,
+                                                         seq_len,
+                                                         layer->GetHiddenSize(),
+                                                         layer->GetIntermediateSize(),
+                                                         layer->GetNumHeads(),
+                                                         layer->IsTrainingMode(),
+                                                         layer->GeluCheckpoint())},
+                                  grad_output.options());
+    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
+
     auto grad_input = torch::empty_like(input);
     auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
     auto grad_attn_qkvb = torch::empty_like(attn_qkvb);