提交 02fda711 编写于 作者: C chengduoZH

refine sgd-op

上级 bb58a474
...@@ -61,43 +61,9 @@ $$param\_out = param - learning\_rate * grad$$ ...@@ -61,43 +61,9 @@ $$param\_out = param - learning\_rate * grad$$
} }
}; };
template <typename T>
struct SparseSGDFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input,
const framework::Tensor& learning_rate,
framework::Tensor* output) {
auto in_height = input.height();
auto out_dims = output->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
auto& in_value = input.value();
auto& in_rows = input.rows();
int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
auto* in_data = in_value.data<T>();
auto* out_data = output->data<T>();
auto* lr = learning_rate.data<T>();
for (size_t i = 0; i < in_rows.size(); i++) {
for (int64_t j = 0; j < in_row_numel; j++) {
out_data[in_rows[i] * in_row_numel + j] -=
lr[0] * in_data[i * in_row_numel + j];
}
}
}
};
template struct SparseSGDFunctor<platform::CPUDeviceContext, float>;
template struct SparseSGDFunctor<platform::CPUDeviceContext, double>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker); REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<float>, ops::SGDOpKernel<double>);
sgd, ops::SGDOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SGDOpKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -20,6 +20,19 @@ namespace paddle { ...@@ -20,6 +20,19 @@ namespace paddle {
namespace operators { namespace operators {
namespace { namespace {
template <typename T>
__global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
const int num, T* p_out) {
T lr = learning_rate[0];
int grid_size = blockDim.x * gridDim.x;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += grid_size) {
T g_data = g[i];
T p_data = p[i];
p_out[i] = p_data - lr * g_data;
}
}
template <typename T, int block_size> template <typename T, int block_size>
__global__ void SparseSGDFunctorKernel(const T* selected_rows, __global__ void SparseSGDFunctorKernel(const T* selected_rows,
const int64_t* rows, const int64_t* rows,
...@@ -41,40 +54,65 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows, ...@@ -41,40 +54,65 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows,
} // namespace } // namespace
template <typename T> template <typename T>
struct SparseSGDFunctor<platform::CUDADeviceContext, T> { class SGDOpCUDAKernel : public framework::OpKernel<T> {
void operator()(const platform::CUDADeviceContext& context, public:
const framework::SelectedRows& input, void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor& learning_rate, auto* param = ctx.Input<framework::Tensor>("Param");
framework::Tensor* output) { auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
auto in_height = input.height(); auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto out_dims = output->dims();
auto* grad_var = ctx.InputVar("Grad");
// Actually, all tensors are LoDTensor except SelectedRows.
if (grad_var->IsType<framework::LoDTensor>()) {
param_out->mutable_data<T>(ctx.GetPlace());
auto* grad = ctx.Input<framework::Tensor>("Grad");
auto* grad_data = grad->data<T>();
auto* param_data = param->data<T>();
auto* param_out_data = param_out->data<T>();
int block = 512;
int grid = (param->numel() + block - 1) / block;
SGDKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
grad_data, param_data, learning_rate->data<T>(), param->numel(),
param_out_data);
} else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ(param, param_out);
auto* grad = ctx.Input<framework::SelectedRows>("Grad");
auto in_height = grad->height();
auto out_dims = param_out->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]); PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
auto& in_value = input.value(); auto& in_value = grad->value();
auto& in_rows = input.rows(); auto& in_rows = grad->rows();
int64_t in_row_numel = in_value.numel() / in_rows.size(); int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height); PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);
auto* in_data = in_value.data<T>(); auto* in_data = in_value.data<T>();
auto* out_data = output->data<T>(); auto* out_data = param_out->data<T>();
const int block_size = 256; const int block_size = 256;
dim3 threads(block_size, 1); dim3 threads(block_size, 1);
dim3 grid(1, in_rows.size()); dim3 grid(1, in_rows.size());
SparseSGDFunctorKernel<T, 256><<<grid, threads, 0, context.stream()>>>( SparseSGDFunctorKernel<
in_data, in_rows.data(), learning_rate.data<T>(), out_data, T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
in_data, in_rows.data(), learning_rate->data<T>(), out_data,
in_row_numel); in_row_numel);
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
}
} }
}; };
template struct SparseSGDFunctor<platform::CUDADeviceContext, float>;
template struct SparseSGDFunctor<platform::CUDADeviceContext, double>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(sgd, ops::SGDOpCUDAKernel<float>,
sgd, ops::SGDOpKernel<paddle::platform::CUDADeviceContext, float>, ops::SGDOpCUDAKernel<double>);
ops::SGDOpKernel<paddle::platform::CUDADeviceContext, double>);
...@@ -20,15 +20,7 @@ limitations under the License. */ ...@@ -20,15 +20,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
struct SparseSGDFunctor {
void operator()(const DeviceContext& context,
const framework::SelectedRows& input,
const framework::Tensor& learning_rate,
framework::Tensor* output);
};
template <typename DeviceContext, typename T>
class SGDOpKernel : public framework::OpKernel<T> { class SGDOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -45,21 +37,36 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -45,21 +37,36 @@ class SGDOpKernel : public framework::OpKernel<T> {
auto p = framework::EigenVector<T>::Flatten(*param); auto p = framework::EigenVector<T>::Flatten(*param);
auto g = framework::EigenVector<T>::Flatten(*grad); auto g = framework::EigenVector<T>::Flatten(*grad);
auto o = framework::EigenVector<T>::Flatten(*param_out); auto o = framework::EigenVector<T>::Flatten(*param_out);
auto lr = framework::EigenVector<T>::Flatten(*learning_rate); auto* lr = learning_rate->data<T>();
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> grad_dsize(grad->numel()); o = p - lr[0] * g;
o.device(place) = p - lr.broadcast(grad_dsize) * g;
} else if (grad_var->IsType<framework::SelectedRows>()) { } else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced. // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency. // This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution. // It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ(param, param_out); PADDLE_ENFORCE_EQ(param, param_out);
auto* grad = ctx.Input<framework::SelectedRows>("Grad"); auto* grad = ctx.Input<framework::SelectedRows>("Grad");
SparseSGDFunctor<DeviceContext, T> functor;
functor(ctx.template device_context<DeviceContext>(), *grad, auto in_height = grad->height();
*learning_rate, param_out); auto out_dims = param_out->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
auto& in_value = grad->value();
auto& in_rows = grad->rows();
int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);
auto* in_data = in_value.data<T>();
auto* out_data = param_out->data<T>();
auto* lr = learning_rate->data<T>();
for (size_t i = 0; i < in_rows.size(); i++) {
for (int64_t j = 0; j < in_row_numel; j++) {
out_data[in_rows[i] * in_row_numel + j] -=
lr[0] * in_data[i * in_row_numel + j];
}
}
} else { } else {
PADDLE_THROW("Unsupported Variable Type of Grad"); PADDLE_THROW("Unsupported Variable Type of Grad");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册