提交 182ce51c 编写于 作者: Q qijun

add sparse kernel of sgd operator

上级 3ae9aa93
......@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
......@@ -35,15 +35,15 @@ class SGDOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
"Two input of SGD Op's dimension must be same.");
// TODO(qijun): check dimensions of Param and Grad at complie
// and run time.
ctx->SetOutputDim("ParamOut", param_dim);
}
};
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
SGDOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "Input parameter");
AddInput("LearningRate", "Learning rate of SGD");
......@@ -58,6 +58,38 @@ param_out = param - learning_rate * grad;
)DOC");
}
};
template <typename T>
struct SparseSGDFunctor<platform::CPUPlace, T> {
void operator()(const platform::DeviceContext& ctx,
const framework::SelectedRows& input,
const framework::Tensor& learning_rate,
framework::Tensor* output) {
auto in_height = input.height();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
auto& in_value = input.value();
auto& in_rows = input.rows();
int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
auto* in_data = in_value.data<T>();
auto* out_data = output->data<T>();
auto* lr = learning_rate.data<T>();
for (size_t i = 0; i < in_rows.size(); i++) {
for (int64_t j = 0; j < in_row_numel; j++) {
out_data[in_rows[i] * in_row_numel + j] -=
lr[0] * in_data[i * in_row_numel + j];
}
}
}
};
template struct SparseSGDFunctor<platform::CPUPlace, float>;
} // namespace operators
} // namespace paddle
......
......@@ -14,6 +14,66 @@
#define EIGEN_USE_GPU
#include "paddle/operators/sgd_op.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace {
template <typename T>
__global__ void SparseSGDFunctorKernel(const T* selected_rows,
const int64_t* rows,
const T* learning_rate, T* tensor_out,
int64_t row_numel, int block_size) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
selected_rows += ty * row_numel;
tensor_out += rows[ty] * row_numel;
for (int index = tid; index < row_numel; index += block_size) {
// Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error.
paddle::platform::CudaAtomicSub(tensor_out + index,
learning_rate[0] * selected_rows[index]);
}
}
} // namespace
template <typename T>
struct SparseSGDFunctor<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& ctx,
const framework::SelectedRows& input,
const framework::Tensor& learning_rate,
framework::Tensor* output) {
auto in_height = input.height();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
auto& in_value = input.value();
auto& in_rows = input.rows();
int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
auto* in_data = in_value.data<T>();
auto* out_data = output->data<T>();
int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, in_rows.size());
SparseSGDFunctorKernel<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(in_data, in_rows.data(), learning_rate.data<T>(),
out_data, in_row_numel, block_size);
}
};
template struct SparseSGDFunctor<platform::GPUPlace, float>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sgd,
......
......@@ -15,31 +15,52 @@ limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/selected_rows.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
struct SparseSGDFunctor {
void operator()(const platform::DeviceContext& ctx,
const framework::SelectedRows& input,
const framework::Tensor& learning_rate,
framework::Tensor* output);
};
template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<framework::Tensor>("Param");
auto grad = ctx.Input<framework::Tensor>("Grad");
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto* param = ctx.Input<framework::Tensor>("Param");
auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
param_out->mutable_data<T>(ctx.GetPlace());
auto* grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::Tensor>()) {
param_out->mutable_data<T>(ctx.GetPlace());
auto* grad = ctx.Input<framework::Tensor>("Grad");
auto p = framework::EigenVector<T>::Flatten(*param);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto o = framework::EigenVector<T>::Flatten(*param_out);
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
auto place = ctx.GetEigenDevice<Place>();
auto p = framework::EigenVector<T>::Flatten(*param);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto o = framework::EigenVector<T>::Flatten(*param_out);
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
auto place = ctx.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
o.device(place) = p - lr.broadcast(grad_dsize) * g;
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
o.device(place) = p - lr.broadcast(grad_dsize) * g;
} else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ(param, param_out);
auto* grad = ctx.Input<framework::SelectedRows>("Grad");
SparseSGDFunctor<Place, T> functor;
functor(ctx.device_context(), *grad, *learning_rate, param_out);
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
}
}
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册