From 182ce51c6d73d98420aa91d998a328503eac538d Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 17 Oct 2017 14:48:40 -0700 Subject: [PATCH] add sparse kernel of sgd operator --- paddle/operators/sgd_op.cc | 40 ++++++++++++++++++++++--- paddle/operators/sgd_op.cu | 60 ++++++++++++++++++++++++++++++++++++++ paddle/operators/sgd_op.h | 47 ++++++++++++++++++++--------- 3 files changed, 130 insertions(+), 17 deletions(-) diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index 0f78eeab9b..e26a1c7893 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of SGDOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), @@ -35,15 +35,15 @@ class SGDOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, "Learning rate should have 1 element"); auto param_dim = ctx->GetInputDim("Param"); - PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"), - "Two input of SGD Op's dimension must be same."); + // TODO(qijun): check dimensions of Param and Grad at complie + // and run time. ctx->SetOutputDim("ParamOut", param_dim); } }; class SGDOpMaker : public framework::OpProtoAndCheckerMaker { public: - SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + SGDOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Param", "Input parameter"); AddInput("LearningRate", "Learning rate of SGD"); @@ -58,6 +58,38 @@ param_out = param - learning_rate * grad; )DOC"); } }; + +template +struct SparseSGDFunctor { + void operator()(const platform::DeviceContext& ctx, + const framework::SelectedRows& input, + const framework::Tensor& learning_rate, + framework::Tensor* output) { + auto in_height = input.height(); + auto out_dims = output->dims(); + PADDLE_ENFORCE_EQ(in_height, out_dims[0]); + + auto& in_value = input.value(); + auto& in_rows = input.rows(); + + int64_t in_row_numel = in_value.numel() / in_rows.size(); + PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height); + + auto* in_data = in_value.data(); + auto* out_data = output->data(); + auto* lr = learning_rate.data(); + + for (size_t i = 0; i < in_rows.size(); i++) { + for (int64_t j = 0; j < in_row_numel; j++) { + out_data[in_rows[i] * in_row_numel + j] -= + lr[0] * in_data[i * in_row_numel + j]; + } + } + } +}; + +template struct SparseSGDFunctor; + } // namespace operators } // namespace paddle diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu index f5ba6d3c29..5c28314141 100644 --- a/paddle/operators/sgd_op.cu +++ b/paddle/operators/sgd_op.cu @@ -14,6 +14,66 @@ #define EIGEN_USE_GPU #include "paddle/operators/sgd_op.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { + +namespace { +template +__global__ void SparseSGDFunctorKernel(const T* selected_rows, + const int64_t* rows, + const T* learning_rate, T* tensor_out, + int64_t row_numel, int block_size) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + + selected_rows += ty * row_numel; + tensor_out += rows[ty] * row_numel; + + for (int index = tid; index < row_numel; index += block_size) { + // Since index in rows of SelectedRows can be duplicate, we have to use + // Atomic Operation to avoid concurrent write error. + paddle::platform::CudaAtomicSub(tensor_out + index, + learning_rate[0] * selected_rows[index]); + } +} +} // namespace + +template +struct SparseSGDFunctor { + void operator()(const platform::DeviceContext& ctx, + const framework::SelectedRows& input, + const framework::Tensor& learning_rate, + framework::Tensor* output) { + auto in_height = input.height(); + auto out_dims = output->dims(); + PADDLE_ENFORCE_EQ(in_height, out_dims[0]); + + auto& in_value = input.value(); + auto& in_rows = input.rows(); + + int64_t in_row_numel = in_value.numel() / in_rows.size(); + PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height); + + auto* in_data = in_value.data(); + auto* out_data = output->data(); + + int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(1, in_rows.size()); + SparseSGDFunctorKernel< + T><<(context) + .stream()>>>(in_data, in_rows.data(), learning_rate.data(), + out_data, in_row_numel, block_size); + } +}; + +template struct SparseSGDFunctor; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(sgd, diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index 26f4012f25..a872d7f749 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -15,31 +15,52 @@ limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/framework/selected_rows.h" namespace paddle { namespace operators { +template +struct SparseSGDFunctor { + void operator()(const platform::DeviceContext& ctx, + const framework::SelectedRows& input, + const framework::Tensor& learning_rate, + framework::Tensor* output); +}; + template class SGDOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto param = ctx.Input("Param"); - auto grad = ctx.Input("Grad"); - auto param_out = ctx.Output("ParamOut"); - auto learning_rate = ctx.Input("LearningRate"); + auto* param = ctx.Input("Param"); + auto* param_out = ctx.Output("ParamOut"); + auto* learning_rate = ctx.Input("LearningRate"); - param_out->mutable_data(ctx.GetPlace()); + auto* grad_var = ctx.InputVar("Grad"); + if (grad_var->IsType()) { + param_out->mutable_data(ctx.GetPlace()); + auto* grad = ctx.Input("Grad"); - auto p = framework::EigenVector::Flatten(*param); - auto g = framework::EigenVector::Flatten(*grad); - auto o = framework::EigenVector::Flatten(*param_out); - auto lr = framework::EigenVector::Flatten(*learning_rate); - auto place = ctx.GetEigenDevice(); + auto p = framework::EigenVector::Flatten(*param); + auto g = framework::EigenVector::Flatten(*grad); + auto o = framework::EigenVector::Flatten(*param_out); + auto lr = framework::EigenVector::Flatten(*learning_rate); + auto place = ctx.GetEigenDevice(); - Eigen::DSizes grad_dsize(grad->numel()); - o.device(place) = p - lr.broadcast(grad_dsize) * g; + Eigen::DSizes grad_dsize(grad->numel()); + o.device(place) = p - lr.broadcast(grad_dsize) * g; + } else if (grad_var->IsType()) { + // TODO(qijun): In Sparse SGD operator, in-place update is enforced. + // This manual optimization brings difficulty to track data dependency. + // It's better to find a more elegant solution. + PADDLE_ENFORCE_EQ(param, param_out); + auto* grad = ctx.Input("Grad"); + SparseSGDFunctor functor; + functor(ctx.device_context(), *grad, *learning_rate, param_out); + } else { + PADDLE_THROW("Unsupported Variable Type of Grad"); + } } }; - } // namespace operators } // namespace paddle -- GitLab