提交 7ad13fbf 编写于 作者: Q QI JUN 提交者: GitHub

Merge pull request #4876 from QiJune/sgd_op_sparse_kernel

add sparse update kernel for sgd operator
......@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
......@@ -35,15 +35,15 @@ class SGDOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
"Two input of SGD Op's dimension must be same.");
// TODO(qijun): check dimensions of Param and Grad at complie
// and run time.
ctx->SetOutputDim("ParamOut", param_dim);
}
};
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
SGDOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "Input parameter");
AddInput("LearningRate", "Learning rate of SGD");
......@@ -58,6 +58,38 @@ param_out = param - learning_rate * grad;
)DOC");
}
};
template <typename T>
struct SparseSGDFunctor<platform::CPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input,
const framework::Tensor& learning_rate,
framework::Tensor* output) {
auto in_height = input.height();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
auto& in_value = input.value();
auto& in_rows = input.rows();
int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
auto* in_data = in_value.data<T>();
auto* out_data = output->data<T>();
auto* lr = learning_rate.data<T>();
for (size_t i = 0; i < in_rows.size(); i++) {
for (int64_t j = 0; j < in_row_numel; j++) {
out_data[in_rows[i] * in_row_numel + j] -=
lr[0] * in_data[i * in_row_numel + j];
}
}
}
};
template struct SparseSGDFunctor<platform::CPUPlace, float>;
} // namespace operators
} // namespace paddle
......
......@@ -14,6 +14,66 @@
#define EIGEN_USE_GPU
#include "paddle/operators/sgd_op.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace {
template <typename T>
__global__ void SparseSGDFunctorKernel(const T* selected_rows,
const int64_t* rows,
const T* learning_rate, T* tensor_out,
int64_t row_numel, int block_size) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
selected_rows += ty * row_numel;
tensor_out += rows[ty] * row_numel;
for (int index = tid; index < row_numel; index += block_size) {
// Since index in rows of SelectedRows can be duplicate, we have to use
// Atomic Operation to avoid concurrent write error.
paddle::platform::CudaAtomicAdd(
tensor_out + index, -1.0 * learning_rate[0] * selected_rows[index]);
}
}
} // namespace
template <typename T>
struct SparseSGDFunctor<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input,
const framework::Tensor& learning_rate,
framework::Tensor* output) {
auto in_height = input.height();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
auto& in_value = input.value();
auto& in_rows = input.rows();
int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
auto* in_data = in_value.data<T>();
auto* out_data = output->data<T>();
int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid(1, in_rows.size());
SparseSGDFunctorKernel<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(in_data, in_rows.data(), learning_rate.data<T>(),
out_data, in_row_numel, block_size);
}
};
template struct SparseSGDFunctor<platform::GPUPlace, float>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sgd,
......
......@@ -15,20 +15,32 @@ limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/selected_rows.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
struct SparseSGDFunctor {
void operator()(const platform::DeviceContext& context,
const framework::SelectedRows& input,
const framework::Tensor& learning_rate,
framework::Tensor* output);
};
template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<framework::Tensor>("Param");
auto grad = ctx.Input<framework::Tensor>("Grad");
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto* param = ctx.Input<framework::Tensor>("Param");
auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto* grad_var = ctx.InputVar("Grad");
// Actually, all tensors are LoDTensor except SelectedRows.
if (grad_var->IsType<framework::LoDTensor>()) {
param_out->mutable_data<T>(ctx.GetPlace());
auto* grad = ctx.Input<framework::Tensor>("Grad");
auto p = framework::EigenVector<T>::Flatten(*param);
auto g = framework::EigenVector<T>::Flatten(*grad);
......@@ -38,8 +50,18 @@ class SGDOpKernel : public framework::OpKernel<T> {
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
o.device(place) = p - lr.broadcast(grad_dsize) * g;
} else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ(param, param_out);
auto* grad = ctx.Input<framework::SelectedRows>("Grad");
SparseSGDFunctor<Place, T> functor;
functor(ctx.device_context(), *grad, *learning_rate, param_out);
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
}
}
};
} // namespace operators
} // namespace paddle
......@@ -154,7 +154,15 @@ PYBIND11_PLUGIN(core) {
py::return_value_policy::reference)
.def("set_height", &SelectedRows::set_height)
.def("height", &SelectedRows::height)
.def("set_rows", &SelectedRows::set_rows)
.def("set_rows",
[](SelectedRows &self, std::vector<int64_t> rows) {
#ifndef PADDLE_WITH_CUDA
self.set_rows(rows);
#else
Vector<int64_t> new_rows(rows);
self.set_rows(new_rows);
#endif
})
.def("rows", [](SelectedRows &self) {
#ifndef PADDLE_WITH_CUDA
return self.rows();
......@@ -187,6 +195,11 @@ All parameter, weight, gradient are variables in Paddle.
return self.GetMutable<LoDTensor>();
},
py::return_value_policy::reference)
.def("get_selected_rows",
[](Variable &self) -> SelectedRows * {
return self.GetMutable<SelectedRows>();
},
py::return_value_policy::reference)
.def("get_net",
[](Variable &self) -> operators::NetOp * {
return self.GetMutable<operators::NetOp>();
......
......@@ -8,29 +8,30 @@ class TestSelectedRows(unittest.TestCase):
place = core.CPUPlace()
height = 10
rows = [0, 4, 7]
row_numel = 10
selcted_rows = core.SelectedRows(rows, row_numel)
np_array = np.ones((len(rows), height)).astype("float32")
row_numel = 12
selected_rows = core.SelectedRows(rows, height)
np_array = np.ones((len(rows), row_numel)).astype("float32")
np_array[0, 0] = 2.0
np_array[2, 8] = 4.0
tensor = selcted_rows.get_tensor()
tensor = selected_rows.get_tensor()
tensor.set(np_array, place)
# compare rows
self.assertEqual(0, selcted_rows.rows()[0])
self.assertEqual(4, selcted_rows.rows()[1])
self.assertEqual(7, selcted_rows.rows()[2])
self.assertEqual(0, selected_rows.rows()[0])
self.assertEqual(4, selected_rows.rows()[1])
self.assertEqual(7, selected_rows.rows()[2])
# compare height
self.assertEqual(10, selcted_rows.height())
self.assertEqual(10, selected_rows.height())
# compare tensor
self.assertAlmostEqual(2.0,
selcted_rows.get_tensor().get_float_element(0))
selected_rows.get_tensor().get_float_element(0))
self.assertAlmostEqual(1.0,
selcted_rows.get_tensor().get_float_element(1))
selected_rows.get_tensor().get_float_element(1))
self.assertAlmostEqual(
4.0, selcted_rows.get_tensor().get_float_element(2 * row_numel + 8))
4.0,
selected_rows.get_tensor().get_float_element(2 * row_numel + 8))
if __name__ == "__main__":
......
import unittest
import numpy as np
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
from op_test import OpTest
......@@ -17,5 +19,70 @@ class TestSGDOp(OpTest):
self.check_output()
class TestSparseSGDOp(unittest.TestCase):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Grad Variable
height = 10
rows = [0, 4, 7]
row_numel = 12
grad_selected_rows = scope.var('Grad').get_selected_rows()
grad_selected_rows.set_height(height)
grad_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
np_array[0, 0] = 2.0
np_array[2, 8] = 4.0
grad_tensor = grad_selected_rows.get_tensor()
grad_tensor.set(np_array, place)
# create and initialize Param Variable
param = scope.var('Param').get_tensor()
param_array = np.full((height, row_numel), 5.0).astype("float32")
param.set(param_array, place)
# create and initialize LeraningRate Variable
lr = scope.var('LearningRate').get_tensor()
lr_array = np.full((1), 2.0).astype("float32")
lr.set(lr_array, place)
# create and run sgd operator
sgd_op = Operator(
"sgd",
Param='Param',
Grad='Grad',
ParamOut='Param',
LearningRate='LearningRate')
ctx = core.DeviceContext.create(place)
sgd_op.run(scope, ctx)
# get and compare result
result_array = np.array(param)
# rows[0] = 0, 5.0 - 2.0 * 2.0
self.assertAlmostEqual(1.0, result_array[rows[0], 0])
# rows[0] = 0, 5.0 - 2.0 * 1.0
self.assertAlmostEqual(3.0, result_array[rows[0], 2])
# 5.0 - 2.0 * 0.0
self.assertAlmostEqual(5.0, result_array[1, 0])
# rows[1] = 4, 5.0 - 2.0 * 1.0
self.assertAlmostEqual(3.0, result_array[rows[1], 10])
# 5.0 - 2.0 * 0.0
self.assertAlmostEqual(5.0, result_array[5, 8])
# rows[2] = 7, 5.0 - 2.0 * 1.0
self.assertAlmostEqual(3.0, result_array[rows[2], 1])
# rows[2] = 7, 5.0 - 2.0 * 4.0
self.assertAlmostEqual(-3.0, result_array[rows[2], 8])
def test_sparse_sgd(self):
places = [core.CPUPlace()]
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
for place in places:
self.check_with_place(place)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册