未验证 提交 43a3af86 编写于 作者: C chengduo 提交者: GitHub

refine sgd_op (#13626)

test=develop
上级 adae0a3b
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #include <algorithm>
#include "paddle/fluid/operators/sgd_op.h" #include "paddle/fluid/operators/sgd_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
...@@ -33,22 +33,21 @@ __global__ void SGDKernel(const T* g, const T* p, const T* learning_rate, ...@@ -33,22 +33,21 @@ __global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
} }
} }
template <typename T, int block_size> template <typename T>
__global__ void SparseSGDFunctorKernel(const T* selected_rows, __global__ void SparseSGDFunctorKernel(const T* selected_rows,
const int64_t* rows, const int64_t* rows,
const T* learning_rate, T* tensor_out, const T* learning_rate, T* tensor_out,
int64_t row_numel) { int64_t row_numel, int64_t limit) {
const int ty = blockIdx.y; for (int64_t i = blockIdx.x; i < limit; i += gridDim.x) {
int tid = threadIdx.x; const T* selected_rows_ptr = selected_rows + i * row_numel;
T* tensor_out_ptr = tensor_out + rows[i] * row_numel;
selected_rows += ty * row_numel; for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) {
tensor_out += rows[ty] * row_numel;
for (int index = tid; index < row_numel; index += block_size) {
// Since index in rows of SelectedRows can be duplicate, we have to use // Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error. // Atomic Operation to avoid concurrent write error.
paddle::platform::CudaAtomicAdd( paddle::platform::CudaAtomicAdd(
tensor_out + index, -1.0 * learning_rate[0] * selected_rows[index]); tensor_out_ptr + index,
-1.0 * learning_rate[0] * selected_rows_ptr[index]);
}
} }
} }
} // namespace } // namespace
...@@ -97,13 +96,15 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> { ...@@ -97,13 +96,15 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
auto* in_data = in_value.data<T>(); auto* in_data = in_value.data<T>();
auto* out_data = param_out->data<T>(); auto* out_data = param_out->data<T>();
const int block_size = 256; const int kThreadsPerBlock = 256;
dim3 threads(block_size, 1); int thread_x = kThreadsPerBlock;
dim3 grid(1, in_rows.size()); int max_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
SparseSGDFunctorKernel< int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
SparseSGDFunctorKernel<<<max_blocks, thread_x, 0,
ctx.cuda_device_context().stream()>>>(
in_data, in_rows.CUDAData(ctx.GetPlace()), learning_rate->data<T>(), in_data, in_rows.CUDAData(ctx.GetPlace()), learning_rate->data<T>(),
out_data, in_row_numel); out_data, in_row_numel, in_rows.size());
} else { } else {
PADDLE_THROW("Unsupported Variable Type of Grad"); PADDLE_THROW("Unsupported Variable Type of Grad");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册