Unverified commit 33703da8, authored by jiangcheng and committed by GitHub

[Cherry-pick] Optimize update_loss_scaling_op(#32554) (#32606)

* optimize update_loss_scaling_op by fusing the per-tensor for loop into one kernel, test=develop

* remove useless while loop and optimize variable names, test=develop

* optimize variable name from out_addrs_tensor to out_addrs_mem, test=develop

* optimize variable names for readability by changing the prefix identifier from t_ to local_
Parent 32203c38
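The commit message above describes the change: instead of launching one small CUDA kernel per parameter tensor, the op builds a prefix-sum "starts" array over the tensor sizes and launches a single kernel over the concatenated element range. As a point of reference, here is a minimal sketch, not Paddle code, of the per-tensor launch pattern that the diff below removes; the kernel name fill, the helper fill_all_naive, and the zero-size guard are illustrative assumptions.

#include <cstdint>
#include <cuda_runtime.h>
#include <vector>

// One grid-stride fill kernel, launched once per tensor in the naive scheme.
__global__ void fill(float* data, int64_t n, float value) {
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    data[i] = value;
  }
}

// Naive host loop: one kernel launch per tensor.
void fill_all_naive(const std::vector<float*>& outs,
                    const std::vector<int64_t>& sizes, cudaStream_t stream) {
  for (size_t i = 0; i < outs.size(); ++i) {
    if (sizes[i] == 0) continue;  // skip empty tensors, grid must be >= 1
    int block = 1024;
    int grid = static_cast<int>((sizes[i] + block - 1) / block);
    fill<<<grid, block, 0, stream>>>(outs[i], sizes[i], 0.0f);
  }
}

With dozens of gradient tensors per optimizer step, the launch overhead of this loop dominates; the fused kernels in the diff replace it with a single launch whose extra cost is one host-to-device copy of the pointer and starts tables.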
@@ -39,33 +39,36 @@ __global__ void CheckFiniteAndUnscale(const T** xs, const MT* scale,
   __syncthreads();
 
   const int64_t num = s_starts[size];
-  int pre_xs_index = 0;
-  bool t_found_inf = false;
-  const MT t_scale = *scale;
+  int xs_index = 0;
+  bool local_found_inf = false;
+  const MT local_scale = *scale;
   for (int64_t idx = tid; idx < num; idx += gridDim.x * blockDim.x) {
-    // get the xs's index of thread
-    int xs_index = pre_xs_index;
-    while (idx < s_starts[xs_index]) xs_index++;
-    // avoid some tensor's numel is zero
-    while (idx >= s_starts[xs_index]) xs_index++;
-    pre_xs_index = xs_index - 1;
+    // get the "out" index of "id"
+    // For example:
+    // idx = 15, starts = [0, 10, 10, 20, 30]
+    // because 10 <= idx < 20 ==>
+    // the idx element locate in the 3rd tensor (notice the 2nd tensor size is
+    // 0)
+    int next_xs_index = xs_index;
+    while (idx >= s_starts[next_xs_index]) next_xs_index++;
+    xs_index = next_xs_index - 1;
 
     // get in data and out data
-    const T* in = xs[pre_xs_index];
-    T* out = outs[pre_xs_index];
-    int64_t in_idx = idx - s_starts[pre_xs_index];
+    const T* in = xs[xs_index];
+    T* out = outs[xs_index];
+    int64_t in_idx = idx - s_starts[xs_index];
 
     // Unscale
-    MT val = static_cast<MT>(in[in_idx]) * t_scale;
+    MT val = static_cast<MT>(in[in_idx]) * local_scale;
     T narrow_val = static_cast<T>(val);
     out[in_idx] = narrow_val;
 
     // CheckFinite
     if (!isfinite(narrow_val)) {
-      t_found_inf = true;
+      local_found_inf = true;
     }
   }
-  if (t_found_inf) {
+  if (local_found_inf) {
     *found_inf = true;
   }
 }
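The comment block in the kernel above walks through idx = 15 with starts = [0, 10, 10, 20, 30]. A small host-only C++ check of the same lookup (illustrative, not part of the op) makes the mapping concrete: advancing while idx >= starts[next] lands on next = 3, so the element belongs to tensor index 2 (the 3rd tensor) at offset 5, and the zero-sized 2nd tensor is skipped naturally.

#include <cstdint>
#include <cstdio>

int main() {
  // starts for tensor sizes [10, 0, 10, 10]; starts[i] is where tensor i begins
  const int64_t starts[] = {0, 10, 10, 20, 30};
  int64_t idx = 15;

  // same scan as the kernel: advance until starts[next] > idx,
  // then the owning tensor is next - 1 (zero-size tensors fall through)
  int next = 0;
  while (idx >= starts[next]) next++;
  int tensor = next - 1;

  printf("flat index %lld -> tensor %d, offset %lld\n", (long long)idx, tensor,
         (long long)(idx - starts[tensor]));  // prints: tensor 2, offset 5
  return 0;
}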
@@ -94,28 +97,30 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
         scale_data, inverse_scale_v, found_inf_data);
 
     size_t xs_size = xs.size();
+    const auto& cpu_place = platform::CPUPlace();
     // calculate each tensor's start index and copy to device
     auto h_starts_tensor =
-        memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t));
+        memory::Alloc(cpu_place, (xs_size + 1) * sizeof(int64_t));
     int64_t* h_starts = reinterpret_cast<int64_t*>(h_starts_tensor->ptr());
 
     auto d_starts_tensor =
         memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
     int64_t* d_starts = reinterpret_cast<int64_t*>(d_starts_tensor->ptr());
 
+    // the start index value of each tensor is
+    // the sum of previous tensor's size. For example:
+    // xs = [10, 0, 10, 10] ==> starts = [0, 10, 10, 20, 30]
     h_starts[0] = 0;
     for (int i = 1; i <= xs_size; i++) {
-      // the start index value of each tensor is
-      // the sum of previous tensor's size
       h_starts[i] = h_starts[i - 1] + xs[i - 1]->numel();
     }
     int64_t total_num = h_starts[xs_size];
     memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
-                 d_starts, platform::CPUPlace(), h_starts,
-                 (xs_size + 1) * sizeof(int64_t), dev_ctx.stream());
+                 d_starts, cpu_place, h_starts, (xs_size + 1) * sizeof(int64_t),
+                 dev_ctx.stream());
 
     // copy each tensor's data address to device
-    auto h_mem = memory::Alloc(platform::CPUPlace(), 2 * xs_size * sizeof(T*));
+    auto h_mem = memory::Alloc(cpu_place, 2 * xs_size * sizeof(T*));
     const T** h_xs = reinterpret_cast<const T**>(h_mem->ptr());
     T** h_outs = reinterpret_cast<T**>(h_mem->ptr()) + xs_size;
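The host side above packs both pointer tables into one buffer: the first xs_size entries of h_mem hold the input addresses (h_xs) and the next xs_size entries hold the output addresses (h_outs), so one memory::Copy publishes both tables. Below is a minimal sketch of the same layout with plain CUDA runtime calls standing in for Paddle's memory::Alloc/memory::Copy; the function name upload_addr_table and the synchronous copy are assumptions for brevity, while the real code copies asynchronously on the op's stream.

#include <cstring>
#include <cuda_runtime.h>
#include <vector>

// Pack input and output pointer tables into one host buffer (inputs first,
// outputs second) and publish both with a single host-to-device copy.
void upload_addr_table(const std::vector<const float*>& xs,
                       const std::vector<float*>& outs,
                       const float*** d_xs, float*** d_outs) {
  size_t n = xs.size();
  std::vector<void*> h_mem(2 * n);
  std::memcpy(h_mem.data(), xs.data(), n * sizeof(void*));
  std::memcpy(h_mem.data() + n, outs.data(), n * sizeof(void*));

  void* d_mem = nullptr;
  cudaMalloc(&d_mem, 2 * n * sizeof(void*));
  cudaMemcpy(d_mem, h_mem.data(), 2 * n * sizeof(void*),
             cudaMemcpyHostToDevice);

  *d_xs = reinterpret_cast<const float**>(d_mem);   // first n entries
  *d_outs = reinterpret_cast<float**>(d_mem) + n;   // next n entries
}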
@@ -128,16 +133,18 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
       h_outs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
     }
     memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), d_xs,
-                 platform::CPUPlace(), h_xs, 2 * xs_size * sizeof(T*),
-                 dev_ctx.stream());
+                 cpu_place, h_xs, 2 * xs_size * sizeof(T*), dev_ctx.stream());
 
     // Launch Kernel
-    int block = 1024;
-    int block_num = block * 20;  // each thread deal with 20 number
-    int grid = (total_num + block_num - 1) / block_num;
+    int threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
+    int elements_per_block =
+        threads_per_block * 20;  // each thread deal with 20 number
+    int blocks_per_grid =
+        (total_num + elements_per_block - 1) / elements_per_block;
     VLOG(3) << "launch kernel";
-    CheckFiniteAndUnscale<T, MPDType><<<
-        grid, block, (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
+    CheckFiniteAndUnscale<
+        T, MPDType><<<blocks_per_grid, threads_per_block,
+                      (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
         d_xs, inverse_scale_v, xs_size, d_starts, found_inf_data, d_outs);
     VLOG(3) << "finish kernel";
   }
...
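The launch configuration above assigns roughly 20 elements to each thread and lets the grid-stride loop inside the kernel cover the remainder; the third launch parameter, (xs_size + 1) * sizeof(int64_t), is the dynamic shared memory needed to stage the starts table. A worked example of the arithmetic, with a made-up total_num purely for illustration:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // each block covers threads_per_block * 20 elements; the grid is the
  // ceiling division of the total element count by that block coverage
  int64_t total_num = 50000;  // e.g. sum of all tensor sizes
  int threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
  int elements_per_block = threads_per_block * 20;                 // 20480
  int blocks_per_grid =
      (total_num + elements_per_block - 1) / elements_per_block;   // 3
  printf("<<<%d, %d>>> with a grid-stride loop over %lld elements\n",
         blocks_per_grid, threads_per_block, (long long)total_num);
  return 0;
}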
@@ -34,13 +34,39 @@ __global__ void GpuUpdateLossScaling(
 }
 
 template <typename T>
-__global__ void FillIf(T* data, const int64_t num, const T value,
-                       const bool* has_inf) {
-  if (*has_inf) {
-    int tid = threadIdx.x + blockIdx.x * blockDim.x;
-    for (int i = tid; i < num; i += blockDim.x * gridDim.x) {
-      data[i] = value;
-    }
+__global__ void FusedFillIf(T** outs, const size_t xs_size,
+                            const int64_t* starts, const T value,
+                            const bool* has_inf) {
+  if (!(*has_inf)) return;
+
+  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+  // copy starts array from global memory to shared memory
+  extern __shared__ int64_t s_starts[];
+  for (int i = threadIdx.x; i <= xs_size; i += blockDim.x) {
+    s_starts[i] = starts[i];
+  }
+  __syncthreads();
+
+  const int64_t total_num = s_starts[xs_size];
+  int out_index = 0;
+
+  for (int64_t id = tid; id < total_num; id += blockDim.x * gridDim.x) {
+    // get the "out" index of "id"
+    // For example:
+    // id = 15, starts = [0, 10, 10, 20, 30]
+    // because 10 <= id < 20 ==>
+    // the id element locate in the 3rd tensor (notice the 2nd tensor size is 0)
+    int next_out_index = out_index;
+    while (id >= s_starts[next_out_index]) next_out_index++;
+    out_index = next_out_index - 1;
+
+    // get data pointer and index
+    T* out_data = outs[out_index];
+    int64_t idx = id - s_starts[out_index];
+
+    // set value
+    out_data[idx] = value;
   }
 }
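FusedFillIf stages the starts table in dynamically sized shared memory (extern __shared__), which is why every launch passes (xs_size + 1) * sizeof(int64_t) as the third launch parameter. A stripped-down, self-contained sketch of just that staging step; the kernel name stage_starts and the sizes are illustrative assumptions, not Paddle code:

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void stage_starts(const int64_t* starts, int n, int64_t* total_out) {
  extern __shared__ int64_t s_starts[];
  // cooperative copy: thread i copies entries i, i + blockDim.x, ...
  for (int i = threadIdx.x; i <= n; i += blockDim.x) {
    s_starts[i] = starts[i];
  }
  __syncthreads();
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    *total_out = s_starts[n];  // last entry is the total element count
  }
}

int main() {
  const int n = 4;
  int64_t h_starts[n + 1] = {0, 10, 10, 20, 30};
  int64_t *d_starts, *d_total;
  cudaMalloc(&d_starts, (n + 1) * sizeof(int64_t));
  cudaMalloc(&d_total, sizeof(int64_t));
  cudaMemcpy(d_starts, h_starts, (n + 1) * sizeof(int64_t),
             cudaMemcpyHostToDevice);

  // dynamic shared memory size must cover the n + 1 staged entries
  stage_starts<<<1, 128, (n + 1) * sizeof(int64_t)>>>(d_starts, n, d_total);

  int64_t h_total = 0;
  cudaMemcpy(&h_total, d_total, sizeof(int64_t), cudaMemcpyDeviceToHost);
  printf("total elements staged in shared memory: %lld\n", (long long)h_total);
  return 0;
}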
@@ -68,15 +94,52 @@ class LazyZeros<platform::CUDADeviceContext, T> {
                   const bool* found_inf_data,
                   const std::vector<const framework::Tensor*>& xs,
                   const std::vector<framework::Tensor*>& outs) const {
-    for (size_t i = 0; i < xs.size(); ++i) {
-      auto* out = outs[i];
-      T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
-      int64_t num = out->numel();
-      int block = 1024;
-      int grid = (block - 1 + num) / block;
-      FillIf<<<grid, block, 0, dev_ctx.stream()>>>(
-          out_data, num, static_cast<T>(0), found_inf_data);
+    size_t xs_size = xs.size();
+    const auto& cpu_place = platform::CPUPlace();
+    // alloc each tensor's start index and copy to device
+    auto h_in_starts_mem =
+        memory::Alloc(cpu_place, (xs_size + 1) * sizeof(int64_t));
+    int64_t* h_starts = reinterpret_cast<int64_t*>(h_in_starts_mem->ptr());
+
+    auto d_in_starts_mem =
+        memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
+    int64_t* d_starts = reinterpret_cast<int64_t*>(d_in_starts_mem->ptr());
+
+    // the start index value of each tensor is
+    // the sum of previous tensor's size. For example:
+    // outs = [10, 0, 10, 10] ==> starts = [0, 10, 10, 20, 30]
+    h_starts[0] = 0;
+    for (int i = 0; i < xs_size; i++) {
+      h_starts[i + 1] = h_starts[i] + outs[i]->numel();
+    }
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+                 d_starts, cpu_place, h_starts, (xs_size + 1) * sizeof(int64_t),
+                 dev_ctx.stream());
+
+    // copy each tensor of "outs" data address array to device
+    auto h_out_addrs_mem = memory::Alloc(cpu_place, xs_size * sizeof(T*));
+    T** h_out_addrs = reinterpret_cast<T**>(h_out_addrs_mem->ptr());
+
+    auto d_out_addrs_mem = memory::Alloc(dev_ctx, xs_size * sizeof(T*));
+    T** d_out_addrs = reinterpret_cast<T**>(d_out_addrs_mem->ptr());
+
+    for (size_t i = 0; i < xs_size; ++i) {
+      h_out_addrs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
     }
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+                 d_out_addrs, cpu_place, h_out_addrs, xs_size * sizeof(T*),
+                 dev_ctx.stream());
+
+    // launch cuda kernel
+    int64_t total_num = h_starts[xs_size];
+    int64_t threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
+    int64_t elements_per_block =
+        threads_per_block * 50;  // each thread deal with 50 data
+    int64_t blocks_per_grid =
+        (total_num + elements_per_block - 1) / elements_per_block;
+    FusedFillIf<T><<<blocks_per_grid, threads_per_block,
+                     (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
+        d_out_addrs, xs_size, d_starts, static_cast<T>(0), found_inf_data);
   }
 };
...
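A note on the design choice visible in LazyZeros: found_inf lives on the device and FusedFillIf early-returns when the flag is false, so the zero-fill launch is always enqueued and the host never synchronizes to read the flag back. The sketch below illustrates that pattern with plain CUDA; names and sizes are illustrative, not Paddle code.

#include <cstdio>
#include <cuda_runtime.h>

// The kernel decides on-device whether to do any work, so the host can
// enqueue it unconditionally and keep the stream fully asynchronous.
__global__ void zero_if_flag(float* data, int64_t n, const bool* flag) {
  if (!(*flag)) return;  // whole grid exits early when nothing overflowed
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    data[i] = 0.0f;
  }
}

int main() {
  const int64_t n = 1024;
  float* d_data;
  bool* d_flag;
  cudaMalloc(&d_data, n * sizeof(float));
  cudaMalloc(&d_flag, sizeof(bool));
  bool overflow = true;  // pretend a previous kernel detected an inf/nan
  cudaMemcpy(d_flag, &overflow, sizeof(bool), cudaMemcpyHostToDevice);

  // always launched; the branch on d_flag happens on the GPU
  zero_if_flag<<<4, 256>>>(d_data, n, d_flag);
  cudaDeviceSynchronize();
  printf("kernel returned, no host-side read of the flag was needed\n");
  return 0;
}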