From 02fda711c8d5f64441af00f29b49e8cdca8b66f1 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 23 Dec 2017 16:29:11 +0800 Subject: [PATCH] refine sgd-op --- paddle/operators/sgd_op.cc | 36 +------------ paddle/operators/sgd_op.cu | 100 +++++++++++++++++++++++++------------ paddle/operators/sgd_op.h | 41 ++++++++------- 3 files changed, 94 insertions(+), 83 deletions(-) diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index fb4b43e47..a11c9624c 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -61,43 +61,9 @@ $$param\_out = param - learning\_rate * grad$$ } }; -template -struct SparseSGDFunctor { - void operator()(const platform::CPUDeviceContext& context, - const framework::SelectedRows& input, - const framework::Tensor& learning_rate, - framework::Tensor* output) { - auto in_height = input.height(); - auto out_dims = output->dims(); - PADDLE_ENFORCE_EQ(in_height, out_dims[0]); - - auto& in_value = input.value(); - auto& in_rows = input.rows(); - - int64_t in_row_numel = in_value.numel() / in_rows.size(); - PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height); - - auto* in_data = in_value.data(); - auto* out_data = output->data(); - auto* lr = learning_rate.data(); - - for (size_t i = 0; i < in_rows.size(); i++) { - for (int64_t j = 0; j < in_row_numel; j++) { - out_data[in_rows[i] * in_row_numel + j] -= - lr[0] * in_data[i * in_row_numel + j]; - } - } - } -}; - -template struct SparseSGDFunctor; -template struct SparseSGDFunctor; - } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker); -REGISTER_OP_CPU_KERNEL( - sgd, ops::SGDOpKernel, - ops::SGDOpKernel); +REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel, ops::SGDOpKernel); diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu index a3c0db7e5..e68ce92bb 100644 --- a/paddle/operators/sgd_op.cu +++ b/paddle/operators/sgd_op.cu @@ -20,6 +20,19 @@ namespace paddle { namespace operators { namespace { + +template +__global__ void SGDKernel(const T* g, const T* p, const T* learning_rate, + const int num, T* p_out) { + T lr = learning_rate[0]; + int grid_size = blockDim.x * gridDim.x; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += grid_size) { + T g_data = g[i]; + T p_data = p[i]; + p_out[i] = p_data - lr * g_data; + } +} + template __global__ void SparseSGDFunctorKernel(const T* selected_rows, const int64_t* rows, @@ -41,40 +54,65 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows, } // namespace template -struct SparseSGDFunctor { - void operator()(const platform::CUDADeviceContext& context, - const framework::SelectedRows& input, - const framework::Tensor& learning_rate, - framework::Tensor* output) { - auto in_height = input.height(); - auto out_dims = output->dims(); - PADDLE_ENFORCE_EQ(in_height, out_dims[0]); - - auto& in_value = input.value(); - auto& in_rows = input.rows(); - - int64_t in_row_numel = in_value.numel() / in_rows.size(); - PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height); - - auto* in_data = in_value.data(); - auto* out_data = output->data(); - - const int block_size = 256; - dim3 threads(block_size, 1); - dim3 grid(1, in_rows.size()); - SparseSGDFunctorKernel<<>>( - in_data, in_rows.data(), learning_rate.data(), out_data, - in_row_numel); +class SGDOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* param = ctx.Input("Param"); + auto* param_out = ctx.Output("ParamOut"); + auto* learning_rate = ctx.Input("LearningRate"); + + auto* grad_var = ctx.InputVar("Grad"); + // Actually, all tensors are LoDTensor except SelectedRows. + if (grad_var->IsType()) { + param_out->mutable_data(ctx.GetPlace()); + auto* grad = ctx.Input("Grad"); + auto* grad_data = grad->data(); + auto* param_data = param->data(); + auto* param_out_data = param_out->data(); + + int block = 512; + int grid = (param->numel() + block - 1) / block; + + SGDKernel<<>>( + grad_data, param_data, learning_rate->data(), param->numel(), + param_out_data); + + } else if (grad_var->IsType()) { + // TODO(qijun): In Sparse SGD operator, in-place update is enforced. + // This manual optimization brings difficulty to track data dependency. + // It's better to find a more elegant solution. + PADDLE_ENFORCE_EQ(param, param_out); + auto* grad = ctx.Input("Grad"); + + auto in_height = grad->height(); + auto out_dims = param_out->dims(); + PADDLE_ENFORCE_EQ(in_height, out_dims[0]); + + auto& in_value = grad->value(); + auto& in_rows = grad->rows(); + + int64_t in_row_numel = in_value.numel() / in_rows.size(); + PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height); + + auto* in_data = in_value.data(); + auto* out_data = param_out->data(); + + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(1, in_rows.size()); + SparseSGDFunctorKernel< + T, 256><<>>( + in_data, in_rows.data(), learning_rate->data(), out_data, + in_row_numel); + + } else { + PADDLE_THROW("Unsupported Variable Type of Grad"); + } } }; - -template struct SparseSGDFunctor; -template struct SparseSGDFunctor; - } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - sgd, ops::SGDOpKernel, - ops::SGDOpKernel); +REGISTER_OP_CUDA_KERNEL(sgd, ops::SGDOpCUDAKernel, + ops::SGDOpCUDAKernel); diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index c920025a9..a6c544591 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -20,15 +20,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template -struct SparseSGDFunctor { - void operator()(const DeviceContext& context, - const framework::SelectedRows& input, - const framework::Tensor& learning_rate, - framework::Tensor* output); -}; - -template +template class SGDOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -45,21 +37,36 @@ class SGDOpKernel : public framework::OpKernel { auto p = framework::EigenVector::Flatten(*param); auto g = framework::EigenVector::Flatten(*grad); auto o = framework::EigenVector::Flatten(*param_out); - auto lr = framework::EigenVector::Flatten(*learning_rate); - auto& place = - *ctx.template device_context().eigen_device(); + auto* lr = learning_rate->data(); - Eigen::DSizes grad_dsize(grad->numel()); - o.device(place) = p - lr.broadcast(grad_dsize) * g; + o = p - lr[0] * g; } else if (grad_var->IsType()) { // TODO(qijun): In Sparse SGD operator, in-place update is enforced. // This manual optimization brings difficulty to track data dependency. // It's better to find a more elegant solution. PADDLE_ENFORCE_EQ(param, param_out); auto* grad = ctx.Input("Grad"); - SparseSGDFunctor functor; - functor(ctx.template device_context(), *grad, - *learning_rate, param_out); + + auto in_height = grad->height(); + auto out_dims = param_out->dims(); + PADDLE_ENFORCE_EQ(in_height, out_dims[0]); + + auto& in_value = grad->value(); + auto& in_rows = grad->rows(); + + int64_t in_row_numel = in_value.numel() / in_rows.size(); + PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height); + + auto* in_data = in_value.data(); + auto* out_data = param_out->data(); + auto* lr = learning_rate->data(); + + for (size_t i = 0; i < in_rows.size(); i++) { + for (int64_t j = 0; j < in_row_numel; j++) { + out_data[in_rows[i] * in_row_numel + j] -= + lr[0] * in_data[i * in_row_numel + j]; + } + } } else { PADDLE_THROW("Unsupported Variable Type of Grad"); } -- GitLab