Commit 26aac8d8 authored by phlrain

update

Parent 5b5941c7
@@ -17,7 +17,7 @@
#include <memory>
#include "paddle/fluid/operators/optimizers/momentum_op.h"
-#include "paddle/fluid/operators/optimizers/sgd_op.h"
+#include "paddle/phi/kernels/sgd_kernel.h"
namespace paddle {
namespace operators {
@@ -26,8 +26,7 @@ template <typename DeviceContext, typename T>
class DGCMomentumKernel : public framework::OpKernel<T> {
 public:
  DGCMomentumKernel()
-      : _momentum_op_kernel(new MomentumOpKernel<DeviceContext, T>()),
-        _sgd_op_kernel(new SGDOpKernel<DeviceContext, T>()) {}
+      : _momentum_op_kernel(new MomentumOpKernel<DeviceContext, T>()) {}
  void Compute(const framework::ExecutionContext& context) const override {
    auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
@@ -67,12 +66,68 @@ class DGCMomentumKernel : public framework::OpKernel<T> {
    }
    VLOG(10) << " so use sgd optimizer";
-    return _sgd_op_kernel->Compute(context);
+    const auto* param_var = context.InputVar("Param");
+    const auto* grad_var = context.InputVar("Grad");
+    auto* learning_rate = context.Input<framework::Tensor>("LearningRate");
+    bool multi_precision = context.Attr<bool>("multi_precision");
+    if (param_var->IsType<framework::LoDTensor>()) {
+      auto* param = context.Input<framework::Tensor>("Param");
+      auto* param_out = context.Output<framework::Tensor>("ParamOut");
+      auto* master_param_out =
+          context.Output<framework::Tensor>("MasterParamOut");
+      paddle::optional<const framework::Tensor&> master_param_opt =
+          paddle::none;
+      if (multi_precision) {
+        auto* master_param = context.Input<framework::Tensor>("MasterParam");
+        master_param_opt = *master_param;
+      }
+      if (grad_var->IsType<framework::Tensor>()) {
+        // sgd_dense
+        auto* grad = context.Input<framework::Tensor>("Grad");
+        phi::SGDDenseKernel<T>(
+            static_cast<const typename framework::ConvertToPhiContext<
+                DeviceContext>::TYPE&>(dev_ctx),
+            *param, *learning_rate, *grad, master_param_opt, multi_precision,
+            param_out, master_param_out);
+      } else {
+        // sgd dense param sparse grad
+        auto* grad = context.Input<phi::SelectedRows>("Grad");
+        phi::SGDDenseParamSparseGradKernel<T>(
+            static_cast<const typename framework::ConvertToPhiContext<
+                DeviceContext>::TYPE&>(dev_ctx),
+            *param, *learning_rate, *grad, master_param_opt, multi_precision,
+            param_out, master_param_out);
+      }
+    } else if (param_var->IsType<phi::SelectedRows>() &&
+               grad_var->IsType<phi::SelectedRows>() &&
+               platform::is_cpu_place(context.GetPlace())) {
+      // sgd sparse param sparse grad
+      auto* param = context.Input<phi::SelectedRows>("Param");
+      auto* param_out = context.Output<phi::SelectedRows>("ParamOut");
+      auto* master_param_out =
+          context.Output<phi::SelectedRows>("MasterParamOut");
+      paddle::optional<const phi::SelectedRows&> master_param_opt =
+          paddle::none;
+      if (multi_precision) {
+        auto* master_param = context.Input<phi::SelectedRows>("MasterParam");
+        master_param_opt = *master_param;
+      }
+      auto* grad = context.Input<phi::SelectedRows>("Grad");
+      phi::SGDSparseParamSparseGradKernel<T>(
+          static_cast<const typename framework::ConvertToPhiContext<
+              DeviceContext>::TYPE&>(dev_ctx),
+          *param, *learning_rate, *grad, master_param_opt, multi_precision,
+          param_out, master_param_out);
+    } else {
+      PADDLE_THROW("gdc not support yet");
+    }
  }
 private:
  std::unique_ptr<MomentumOpKernel<DeviceContext, T>> _momentum_op_kernel;
-  std::unique_ptr<SGDOpKernel<DeviceContext, T>> _sgd_op_kernel;
};
} // namespace operators
...
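For orientation, the rewritten Compute above drops the nested SGDOpKernel and calls into phi directly, picking a kernel from the storage types of Param and Grad. A condensed sketch of that dispatch (illustrative only; dev_ctx, param, grad, and the other locals are the ones declared in the diff above):

// Param: LoDTensor,    Grad: LoDTensor     -> phi::SGDDenseKernel
// Param: LoDTensor,    Grad: SelectedRows  -> phi::SGDDenseParamSparseGradKernel
// Param: SelectedRows, Grad: SelectedRows  -> phi::SGDSparseParamSparseGradKernel (CPU only)
// The fluid DeviceContext is cast to its phi counterpart before the call:
auto& phi_ctx = static_cast<const typename framework::ConvertToPhiContext<
    DeviceContext>::TYPE&>(dev_ctx);
phi::SGDDenseKernel<T>(phi_ctx, *param, *learning_rate, *grad,
                       master_param_opt, multi_precision, param_out,
                       master_param_out);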
@@ -166,8 +166,3 @@ REGISTER_OPERATOR(
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    ops::SGDOpInferVarType);
-REGISTER_OP_CPU_KERNEL(
-    sgd, ops::SGDOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SGDOpKernel<paddle::platform::CPUDeviceContext,
-                     paddle::platform::bfloat16>,
-    ops::SGDOpKernel<paddle::platform::CPUDeviceContext, double>);
@@ -166,10 +166,3 @@ class SGDOpKernel<platform::CUDADeviceContext, T>
};
} // namespace operators
} // namespace paddle
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    sgd, ops::SGDOpKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::SGDOpKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::SGDOpKernel<paddle::platform::CUDADeviceContext, plat::float16>);
@@ -221,6 +221,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
  PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
  PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
+  PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows);
  PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
#ifndef PADDLE_WITH_CUSTOM_KERNEL
  PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);
...
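The single added line registers the call-helper specialization that lets a phi kernel declare an optional SelectedRows input. A minimal sketch of what such a parameter looks like on the kernel side (ExampleKernel is a made-up name; the real users are the SGD kernels below):

template <typename T, typename Context>
void ExampleKernel(const Context& dev_ctx,
                   const phi::SelectedRows& param,
                   paddle::optional<const phi::SelectedRows&> master_param,
                   phi::SelectedRows* param_out) {
  // master_param holds paddle::none when the op did not supply MasterParam;
  // dereference it only after checking that a value is present.
  if (master_param) {
    const phi::SelectedRows& mp = *master_param;
    (void)mp;
  }
}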
@@ -14,6 +14,8 @@
#include "paddle/phi/kernels/sgd_kernel.h"
#include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
@@ -112,40 +114,42 @@ void sgd_dense_param_sparse_grad_impl<phi::dtype::bfloat16>(
}
template <typename T, typename Context>
-void SGDKernel(const Context& dev_ctx,
-               const DenseTensor& param,
-               const DenseTensor& learning_rate,
-               const DenseTensor& grad,
-               const DenseTensor& master_param,
-               bool multi_precision,
-               DenseTensor* param_out,
-               DenseTensor* master_param_out) {
+void SGDDenseKernel(const Context& dev_ctx,
+                    const DenseTensor& param,
+                    const DenseTensor& learning_rate,
+                    const DenseTensor& grad,
+                    paddle::optional<const DenseTensor&> master_param,
+                    bool multi_precision,
+                    DenseTensor* param_out,
+                    DenseTensor* master_param_out) {
  dev_ctx.template Alloc<T>(param_out);
  sgd_dense_param_dense_grad_impl<T>(param, learning_rate, grad, param_out);
}
template <typename T, typename Context>
-void SGDKernel(const Context& dev_ctx,
-               const DenseTensor& param,
-               const DenseTensor& learning_rate,
-               const SelectedRows& grad,
-               const DenseTensor& master_param,
-               bool multi_precision,
-               DenseTensor* param_out,
-               DenseTensor* master_param_out) {
+void SGDDenseParamSparseGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& param,
+    const DenseTensor& learning_rate,
+    const SelectedRows& grad,
+    paddle::optional<const DenseTensor&> master_param,
+    bool multi_precision,
+    DenseTensor* param_out,
+    DenseTensor* master_param_out) {
  dev_ctx.template Alloc<T>(param_out);
  sgd_dense_param_sparse_grad_impl<T>(param, learning_rate, grad, param_out);
}
template <typename T, typename Context>
-void SGDKernel(const Context& dev_ctx,
-               const SelectedRows& param,
-               const DenseTensor& learning_rate,
-               const SelectedRows& grad,
-               const SelectedRows& master_param,
-               bool multi_precision,
-               SelectedRows* param_out,
-               SelectedRows* master_param_out) {
+void SGDSparseParamSparseGradKernel(
+    const Context& dev_ctx,
+    const SelectedRows& param,
+    const DenseTensor& learning_rate,
+    const SelectedRows& grad,
+    paddle::optional<const SelectedRows&> master_param,
+    bool multi_precision,
+    SelectedRows* param_out,
+    SelectedRows* master_param_out) {
  // for distributed training, a sparse var may be empty,
  // just skip updating.
  if (grad.rows().size() == 0) {
@@ -183,3 +187,27 @@ void SGDKernel(const Context& dev_ctx,
}
} // namespace phi
+PD_REGISTER_KERNEL(sgd,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SGDDenseKernel,
+                   phi::dtype::bfloat16,
+                   float,
+                   double) {}
+
+PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SGDDenseParamSparseGradKernel,
+                   phi::dtype::bfloat16,
+                   float,
+                   double) {}
+
+PD_REGISTER_KERNEL(sgd_sparse_param_sparse_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SGDSparseParamSparseGradKernel,
+                   phi::dtype::bfloat16,
+                   float,
+                   double) {}
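Once registered, the CPU kernels remain ordinary templated functions and can also be called directly. A minimal single-precision sketch, assuming the tensors have already been created and sized on the CPU place (RunSgdStep is illustrative, not part of the commit):

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/sgd_kernel.h"

void RunSgdStep(const phi::CPUContext& cpu_ctx,
                const phi::DenseTensor& param,
                const phi::DenseTensor& lr,
                const phi::DenseTensor& grad,
                phi::DenseTensor* param_out) {
  // multi_precision is false, so no master parameter is passed or produced.
  phi::SGDDenseKernel<float>(cpu_ctx, param, lr, grad, paddle::none,
                             /*multi_precision=*/false, param_out,
                             /*master_param_out=*/nullptr);
}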
@@ -18,6 +18,9 @@
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_helper.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename MT>
@@ -61,14 +64,15 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows,
}
template <typename T, typename Context>
-void SGDKernel(const Context& dev_ctx,
-               const DenseTensor& param,
-               const DenseTensor& learning_rate,
-               const DenseTensor& grad,
-               const DenseTensor& master_param,
-               bool multi_precision,
-               DenseTensor* param_out,
-               DenseTensor* master_param_out) {
+void SGDDenseKernel(const Context& dev_ctx,
+                    const DenseTensor& param,
+                    const DenseTensor& learning_rate,
+                    const DenseTensor& grad,
+                    paddle::optional<const DenseTensor&> master_param,
+                    bool multi_precision,
+                    DenseTensor* param_out,
+                    DenseTensor* master_param_out) {
+  LOG(ERROR) << "run here";
  using MPDType = typename paddle::operators::details::MPTypeTrait<T>::Type;
  // do check here
  // if (multi_precision) {
@@ -77,7 +81,7 @@ void SGDKernel(const Context& dev_ctx,
  // }
  const MPDType* master_in_data =
-      multi_precision ? master_param.data<MPDType>() : nullptr;
+      multi_precision ? master_param->data<MPDType>() : nullptr;
  MPDType* master_out_data =
      multi_precision
          ? master_param_out->mutable_data<MPDType>(dev_ctx.GetPlace())
@@ -91,20 +95,21 @@ void SGDKernel(const Context& dev_ctx,
      grad.data<T>(),
      learning_rate.data<T>(),
      param.numel(),
-      param_out->mutable_data<T>(ctx.GetPlace()),
+      param_out->mutable_data<T>(dev_ctx.GetPlace()),
      master_in_data,
      master_out_data);
}
template <typename T, typename Context>
-void SGDKernel(const Context& dev_ctx,
-               const DenseTensor& param,
-               const DenseTensor& learning_rate,
-               const SelectedRows& grad,
-               const DenseTensor& master_param,
-               bool multi_precision,
-               DenseTensor* param_out,
-               DenseTensor* master_param_out) {
+void SGDDenseParamSparseGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& param,
+    const DenseTensor& learning_rate,
+    const SelectedRows& grad,
+    paddle::optional<const DenseTensor&> master_param,
+    bool multi_precision,
+    DenseTensor* param_out,
+    DenseTensor* master_param_out) {
  using MPDType = typename paddle::operators::details::MPTypeTrait<T>::Type;
  // do some check here
  // if (multi_precision) {
@@ -113,7 +118,7 @@ void SGDKernel(const Context& dev_ctx,
  // }
  const MPDType* master_in_data =
-      multi_precision ? master_param.data<MPDType>() : nullptr;
+      multi_precision ? master_param->data<MPDType>() : nullptr;
  MPDType* master_out_data =
      multi_precision
          ? master_param_out->mutable_data<MPDType>(dev_ctx.GetPlace())
@@ -155,7 +160,7 @@ void SGDKernel(const Context& dev_ctx,
  int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
  paddle::framework::MixVector<int64_t> mixv_in_rows(&in_rows);
-  SparseSGDFunctorKernel<<<max_blocks, thread_x, 0, dev_ctx..stream()>>>(
+  SparseSGDFunctorKernel<<<max_blocks, thread_x, 0, dev_ctx.stream()>>>(
      in_data,
      mixv_in_rows.CUDAData(dev_ctx.GetPlace()),
      learning_rate.data<T>(),
@@ -164,4 +169,41 @@ void SGDKernel(const Context& dev_ctx,
      in_rows.size());
}
+template <typename T, typename Context>
+void SGDSparseParamSparseGradKernel(
+    const Context& dev_ctx,
+    const SelectedRows& param,
+    const DenseTensor& learning_rate,
+    const SelectedRows& grad,
+    paddle::optional<const SelectedRows&> master_param,
+    bool multi_precision,
+    SelectedRows* param_out,
+    SelectedRows* master_param_out) {
+  PADDLE_THROW("not impl");
+}
} // namespace phi
+PD_REGISTER_KERNEL(sgd,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SGDDenseKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {}
+
+PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SGDDenseParamSparseGradKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {}
+
+PD_REGISTER_KERNEL(sgd_sparse_param_sparse_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SGDSparseParamSparseGradKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {}
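For reference, the math both GPU paths implement is the plain SGD step, with an optional higher-precision master copy when multi_precision drives float16 parameters. A host-side C++ sketch of the per-element update (SgdUpdateSketch is an assumed helper, not the kernel actually launched above):

template <typename T, typename MT>
void SgdUpdateSketch(const T* grad, const T* lr, const T* param,
                     const MT* master_in, int64_t n, T* param_out,
                     MT* master_out) {
  for (int64_t i = 0; i < n; ++i) {
    // Compute in the master type MT (e.g. float for float16 params).
    MT p = master_in ? master_in[i] : static_cast<MT>(param[i]);
    MT updated = p - static_cast<MT>(lr[0]) * static_cast<MT>(grad[i]);
    param_out[i] = static_cast<T>(updated);
    if (master_out) master_out[i] = updated;
  }
}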
@@ -20,33 +20,35 @@
namespace phi {
template <typename T, typename Context>
-void SGDKernel(const Context& dev_ctx,
-               const DenseTensor& param,
-               const DenseTensor& learning_rate,
-               const DenseTensor& grad,
-               const DenseTensor& master_param,
-               bool multi_precision,
-               DenseTensor* param_out,
-               DenseTensor* master_param_out);
+void SGDDenseKernel(const Context& dev_ctx,
+                    const DenseTensor& param,
+                    const DenseTensor& learning_rate,
+                    const DenseTensor& grad,
+                    paddle::optional<const DenseTensor&> master_param,
+                    bool multi_precision,
+                    DenseTensor* param_out,
+                    DenseTensor* master_param_out);
template <typename T, typename Context>
-void SGDKernel(const Context& dev_ctx,
-               const DenseTensor& param,
-               const DenseTensor& learning_rate,
-               const SelectedRows& grad,
-               const DenseTensor& master_param,
-               bool multi_precision,
-               DenseTensor* param_out,
-               DenseTensor* master_param_out);
+void SGDDenseParamSparseGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& param,
+    const DenseTensor& learning_rate,
+    const SelectedRows& grad,
+    paddle::optional<const DenseTensor&> master_param,
+    bool multi_precision,
+    DenseTensor* param_out,
+    DenseTensor* master_param_out);
template <typename T, typename Context>
-void SGDKernel(const Context& dev_ctx,
-               const SelectedRows& param,
-               const DenseTensor& learning_rate,
-               const SelectedRows& grad,
-               const SelectedRows& master_param,
-               bool multi_precision,
-               SelectedRows* param_out,
-               SelectedRows* master_param_out);
+void SGDSparseParamSparseGradKernel(
+    const Context& dev_ctx,
+    const SelectedRows& param,
+    const DenseTensor& learning_rate,
+    const SelectedRows& grad,
+    paddle::optional<const SelectedRows&> master_param,
+    bool multi_precision,
+    SelectedRows* param_out,
+    SelectedRows* master_param_out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SGDOpArgumentMapping(const ArgumentMappingContext& ctx) {
  LOG(ERROR) << "11";
  if (ctx.IsDenseTensorInput("Grad")) {
    LOG(ERROR) << "dense";
    return KernelSignature("sgd",
                           {"Param", "LearningRate", "Grad", "MasterParam"},
                           {"multi_precision"},
                           {"ParamOut", "MasterParamOut"});
  } else if (ctx.IsSelectedRowsInput("Grad")) {
    if (ctx.IsDenseTensorInput("Param")) {
      return KernelSignature("sgd_dense_param_sparse_grad",
                             {"Param", "LearningRate", "Grad", "MasterParam"},
                             {"multi_precision"},
                             {"ParamOut", "MasterParamOut"});
    } else {
      return KernelSignature("sgd_sparse_param_sparse_grad",
                             {"Param", "LearningRate", "Grad", "MasterParam"},
                             {"multi_precision"},
                             {"ParamOut", "MasterParamOut"});
    }
  }

  return KernelSignature("unregistered", {}, {}, {});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(sgd, phi::SGDOpArgumentMapping);
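The mapping function above is what ties the existing sgd op description to the kernels registered in this commit: the storage types of Grad (and Param) select the kernel name, and the listed inputs, attributes, and outputs are forwarded positionally. Spelled out for the dense-param / sparse-grad case (a summary of the code above, not new behavior):

// "sgd_dense_param_sparse_grad" -> phi::SGDDenseParamSparseGradKernel
//   input  "Param"            -> const DenseTensor& param
//   input  "LearningRate"     -> const DenseTensor& learning_rate
//   input  "Grad"             -> const SelectedRows& grad
//   input  "MasterParam"      -> paddle::optional<const DenseTensor&> master_param
//   attr   "multi_precision"  -> bool multi_precision
//   output "ParamOut"         -> DenseTensor* param_out
//   output "MasterParamOut"   -> DenseTensor* master_param_out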