From 72dde4abde2793eaeb297f12da6e621ed9808e77 Mon Sep 17 00:00:00 2001
From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com>
Date: Thu, 27 Feb 2020 17:37:21 +0800
Subject: [PATCH] Refine adam op to improve performance, test=develop (#22346)

* Refine adam op, test=develop

* Fuse kernels together to reduce cpu time.

* Refine paddle enforce, test=develop

* Remove some comments, test=develop

* Refine code, test=develop

* Refine cuda kernel, test=develop

* Refine code according to comments, test=develop
---
 paddle/fluid/operators/coalesce_tensor_op.cc |  12 +-
 paddle/fluid/operators/optimizers/adam_op.cc |  15 +-
 paddle/fluid/operators/optimizers/adam_op.cu | 285 +++++++++++++++-
 paddle/fluid/operators/optimizers/adam_op.h  | 334 +++++++------------
 paddle/fluid/pybind/tensor_py.h              |  15 +-
 python/paddle/fluid/optimizer.py             |  11 +-
 6 files changed, 452 insertions(+), 220 deletions(-)

diff --git a/paddle/fluid/operators/coalesce_tensor_op.cc b/paddle/fluid/operators/coalesce_tensor_op.cc
index 94a446a1c43..5b7bcde21a9 100644
--- a/paddle/fluid/operators/coalesce_tensor_op.cc
+++ b/paddle/fluid/operators/coalesce_tensor_op.cc
@@ -145,7 +145,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
       auto size = lod_tensors[i]->numel();
       PADDLE_ENFORCE_GT(size, 0);
       ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims()
-         << "), ";
+         << ") "
+         << " address:" << lod_tensors[i]->data<void>() << ", ";
       *numel += platform::Alignment(static_cast<size_t>(size) * size_of_dtype,
                                     place) /
                 size_of_dtype;
@@ -160,6 +161,15 @@ class CoalesceTensorOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext *ctx) const override {}
+
+ protected:
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name, const framework::Tensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const override {
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   expected_kernel_type.place_,
+                                   tensor.layout());
+  }
 };

 class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker {
diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc
index 8c9e4ca90c3..86bfd9232a4 100644
--- a/paddle/fluid/operators/optimizers/adam_op.cc
+++ b/paddle/fluid/operators/optimizers/adam_op.cc
@@ -19,7 +19,7 @@ namespace operators {

 using Tensor = framework::Tensor;

-void AdamOp::InferShape(framework::InferShapeContext* ctx) const {
+void AdamOp::InferShape(framework::InferShapeContext *ctx) const {
   PADDLE_ENFORCE_EQ(
       ctx->HasInput("Param"), true,
       platform::errors::NotFound("Input(Param) of AdamOp should not be null."));
@@ -126,11 +126,22 @@ void AdamOp::InferShape(framework::InferShapeContext* ctx) const {
 }

 framework::OpKernelType AdamOp::GetExpectedKernelType(
-    const framework::ExecutionContext& ctx) const {
+    const framework::ExecutionContext &ctx) const {
   auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
   return framework::OpKernelType(input_data_type, ctx.GetPlace());
 }

+framework::OpKernelType AdamOp::GetKernelTypeForVar(
+    const std::string &var_name, const framework::Tensor &tensor,
+    const framework::OpKernelType &expected_kernel_type) const {
+  if (var_name == "Beta1Pow" || var_name == "Beta2Pow") {
+    return expected_kernel_type;
+  } else {
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
+  }
+}
+
 class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
  void Make() override {
diff --git a/paddle/fluid/operators/optimizers/adam_op.cu b/paddle/fluid/operators/optimizers/adam_op.cu
index 4eb2db717d4..b130ffe6464 100644
--- a/paddle/fluid/operators/optimizers/adam_op.cu
+++ b/paddle/fluid/operators/optimizers/adam_op.cu
@@ -13,7 +13,286 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/optimizers/adam_op.h"
+namespace paddle {
+namespace operators {
+
+template <typename T>
+__global__ void AdamKernelREG(T beta1, T beta2, T epsilon, T beta1_pow_,
+                              T beta2_pow_, const T* moment1, T* moment1_out,
+                              const T* moment2, T* moment2_out, const T* lr_,
+                              const T* grad, const T* param, T* param_out,
+                              int ndim) {
+  T lr = *lr_;
+  T beta1_pow = beta1_pow_;
+  T beta2_pow = beta2_pow_;
+
+  lr *=
+      sqrt(static_cast<T>(1.0) - beta2_pow) / (static_cast<T>(1.0) - beta1_pow);
+
+  int id = blockIdx.x * blockDim.x + threadIdx.x;
+
+  for (; id < ndim; id += gridDim.x * blockDim.x) {
+    T p = param[id];
+    T g = grad[id];
+    T mom1 = moment1[id];
+    T mom2 = moment2[id];
+    mom1 = beta1 * mom1 + (static_cast<T>(1.0) - beta1) * g;
+    mom2 = beta2 * mom2 + (static_cast<T>(1.0) - beta2) * g * g;
+    p -= lr * (mom1 / (sqrt(mom2) + epsilon));
+
+    moment1_out[id] = mom1;
+    moment2_out[id] = mom2;
+    param_out[id] = p;
+  }
+}
+
+template <typename T>
+__global__ void AdamKernelMEM(T beta1, T beta2, T epsilon, const T* beta1_pow_,
+                              const T* beta2_pow_, const T* moment1,
+                              T* moment1_out, const T* moment2, T* moment2_out,
+                              const T* lr_, const T* grad, const T* param,
+                              T* param_out, int ndim) {
+  T lr = *lr_;
+  T beta1_pow = *beta1_pow_;
+  T beta2_pow = *beta2_pow_;
+
+  lr *=
+      sqrt(static_cast<T>(1.0) - beta2_pow) / (static_cast<T>(1.0) - beta1_pow);
+
+  int id = blockIdx.x * blockDim.x + threadIdx.x;
+
+  for (; id < ndim; id += gridDim.x * blockDim.x) {
+    T p = param[id];
+    T g = grad[id];
+    T mom1 = moment1[id];
+    T mom2 = moment2[id];
+    mom1 = beta1 * mom1 + (static_cast<T>(1.0) - beta1) * g;
+    mom2 = beta2 * mom2 + (static_cast<T>(1.0) - beta2) * g * g;
+    p -= lr * (mom1 / (sqrt(mom2) + epsilon));
+
+    moment1_out[id] = mom1;
+    moment2_out[id] = mom2;
+    param_out[id] = p;
+  }
+}
+template <typename T>
+__global__ void UpdateBetaPow(T beta1, T beta2, const T* beta1_pow_,
+                              const T* beta2_pow_, T* beta1_pow_out,
+                              T* beta2_pow_out) {
+  *beta1_pow_out = beta1 * beta1_pow_[0];
+  *beta2_pow_out = beta2 * beta2_pow_[0];
+}
+
+template <typename T>
+__global__ void SparseAdamCUDAKernelREG(
+    T beta1, T beta2, T epsilon, const T beta1_pow, const T beta2_pow,
+    const T* mom1_, T* mom1_out_, const T* mom2_, T* mom2_out_, const T* lr_,
+    const T* grad_, const T* param_, T* param_out_, const int64_t* rows_,
+    int64_t row_numel, int64_t row_count, bool lazy_mode, int ndim) {
+  int id = blockIdx.x * blockDim.x + threadIdx.x;
+  T lr = *lr_;
+  lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
+  for (; id < ndim; id += blockDim.x * gridDim.x) {
+    auto row_idx =
+        math::BinarySearch<int64_t>(rows_, row_count, id / row_numel);
+    if (lazy_mode && row_idx < 0) {
+      return;
+    } else {
+      T mom1 = mom1_[id];
+      T mom2 = mom2_[id];
+      T p = param_[id];
+      T g = row_idx >= 0 ? grad_[row_idx * row_numel + id % row_numel] : 0;
+      mom1 = beta1 * mom1 + (1 - beta1) * g;
+      mom2 = beta2 * mom2 + (1 - beta2) * g * g;
+      p -= lr * (mom1 / (sqrt(mom2) + epsilon));
+
+      // Write back to global memory
+      mom1_out_[id] = mom1;
+      mom2_out_[id] = mom2;
+      param_out_[id] = p;
+    }
+  }
+}
+
+template <typename T>
+class AdamOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto* param_var = ctx.InputVar("Param");
+    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
+                      platform::errors::InvalidArgument(
+                          "The Var(%s)'s type should be LoDTensor, "
+                          "but the received is %s",
+                          ctx.InputNames("Param").front(),
+                          framework::ToTypeName(param_var->Type())));
+
+    using paddle::framework::LoDTensor;
+    using paddle::operators::detail::Ref;
+
+    int64_t min_row_size_to_use_multithread =
+        ctx.Attr<int64_t>("min_row_size_to_use_multithread");
+    bool lazy_mode = ctx.Attr<bool>("lazy_mode");
+    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
+    auto* param = ctx.Input<LoDTensor>("Param");
+    auto* grad_var = ctx.InputVar("Grad");
+    auto* mom1 = ctx.Input<LoDTensor>("Moment1");
+    auto* mom2 = ctx.Input<LoDTensor>("Moment2");
+    auto* lr = ctx.Input<LoDTensor>("LearningRate");
+
+    auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
+    auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");
+
+    auto* param_out = ctx.Output<LoDTensor>("ParamOut");
+    auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
+    auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
+    auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
+    auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
+
+    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
+    if (ctx.HasInput("Beta1Tensor")) {
+      auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
+      beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
+    }
+    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
+    if (ctx.HasInput("Beta2Tensor")) {
+      auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
+      beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
+    }
+    VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
+            << "beta2_pow.numel() : " << beta2_pow->numel();
+    VLOG(3) << "param.numel(): " << param->numel();
+    PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
+                      platform::errors::InvalidArgument(
+                          "beta1 pow output size should be 1, but received "
+                          "value is:%d.",
+                          beta1_pow_out->numel()));
+
+    PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
+                      platform::errors::InvalidArgument(
+                          "beta2 pow output size should be 1, but received "
+                          "value is:%d.",
+                          beta2_pow_out->numel()));
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+
+    if (grad_var->IsType<framework::LoDTensor>()) {
+      auto* grad = ctx.Input<LoDTensor>("Grad");
+
+      // update param and moment
+      int threads = 512;
+      int blocks = (param->numel() + threads - 1) / threads;
+
+      if (beta1_pow->place() == platform::CPUPlace() &&
+          beta2_pow->place() == platform::CPUPlace()) {
+        // Compute with betapow in REG
+        AdamKernelREG<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            beta1, beta2, epsilon, *beta1_pow->data<T>(), *beta2_pow->data<T>(),
+            mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
+            mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
+            lr->data<T>(), grad->data<T>(), param->data<T>(),
+            param_out->mutable_data<T>(ctx.GetPlace()), param->numel());
+        // Cpu update
+        beta1_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
+            beta1 * beta1_pow->data<T>()[0];
+        beta2_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
+            beta2 * beta2_pow->data<T>()[0];
+      } else {
+        AdamKernelMEM<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
+            mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
+            mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
+            lr->data<T>(), grad->data<T>(), param->data<T>(),
+            param_out->mutable_data<T>(ctx.GetPlace()), param->numel());
+        // Update with gpu
+        UpdateBetaPow<T><<<1, 32, 0, dev_ctx.stream()>>>(
+            beta1, beta2, beta1_pow->data<T>(), beta2_pow->data<T>(),
+            beta1_pow_out->mutable_data<T>(ctx.GetPlace()),
+            beta2_pow_out->mutable_data<T>(ctx.GetPlace()));
+      }
+
+    } else if (grad_var->IsType<framework::SelectedRows>()) {
+      auto* grad = ctx.Input<framework::SelectedRows>("Grad");
+      if (grad->rows().size() == 0) {
+        VLOG(3) << "grad row size is 0!!";
+        return;
+      }
+
+      std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
+      bool is_strict_sorted = true;
+      for (size_t i = 1; i < cpu_rows.size(); ++i) {
+        if (cpu_rows[i - 1] >= cpu_rows[i]) {
+          is_strict_sorted = false;
+          break;
+        }
+      }
+
+      framework::SelectedRows tmp_grad_merge;
+      const framework::SelectedRows* grad_merge_ptr;
+      if (is_strict_sorted) {
+        grad_merge_ptr = grad;
+      } else {
+        // merge duplicated rows if any.
+        // The rows of grad_merge have been sorted inside MergeAdd functor
+        scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
+        merge_func(ctx.template device_context<platform::CUDADeviceContext>(),
+                   *grad, &tmp_grad_merge, true);
+        grad_merge_ptr = &tmp_grad_merge;
+      }
+      auto& grad_merge = *grad_merge_ptr;
+      auto& grad_tensor = grad_merge.value();
+      const T* grad_data = grad_tensor.template data<T>();
+      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
+      auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
+
+      if (beta1_pow->place() == platform::CPUPlace() &&
+          beta2_pow->place() == platform::CPUPlace()) {
+        int threads = 512;
+        int ndim = param->numel();
+        int blocks = (ndim + threads - 1) / threads;
+
+        SparseAdamCUDAKernelREG<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
+            beta1, beta2, epsilon, *beta1_pow->data<T>(), *beta2_pow->data<T>(),
+            mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
+            mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
+            lr->data<T>(), grad_data, param->data<T>(),
+            param_out->mutable_data<T>(ctx.GetPlace()), rows, row_numel,
+            grad_merge.rows().size(), lazy_mode, ndim);
+        // Update with cpu
+        beta1_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
+            beta1 * beta1_pow->data<T>()[0];
+        beta2_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
+            beta2 * beta2_pow->data<T>()[0];
+      } else {
+        SparseAdamFunctor<T, GPUAdam> functor(
+            beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
+            mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
+            mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
+            lr->data<T>(), grad_data, param->data<T>(),
+            param_out->mutable_data<T>(ctx.GetPlace()), rows, row_numel,
+            grad_merge.rows().size(), lazy_mode);
+
+        // FIXME(minqiyang): remove BinarySearch in GPU later
+        platform::ForRange<platform::CUDADeviceContext> for_range(
+            static_cast<const platform::CUDADeviceContext&>(
+                ctx.device_context()),
+            param->numel());
+        for_range(functor);
+        // update beta1 and beta2
+        UpdateBetaPow<T><<<1, 32, 0, dev_ctx.stream()>>>(
+            beta1, beta2, beta1_pow->data<T>(), beta2_pow->data<T>(),
+            beta1_pow_out->mutable_data<T>(ctx.GetPlace()),
+            beta2_pow_out->mutable_data<T>(ctx.GetPlace()));
+      }
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Variable type not supported by adam_op"));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    adam, ops::AdamOpKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::AdamOpKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(adam, ops::AdamOpCUDAKernel<float>,
+                        ops::AdamOpCUDAKernel<double>);
diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index 99338b1e0c5..461c94976ff 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <math.h>  // for sqrt in CPU and CUDA
 #include <stdint.h>
+#include <memory>
 #include <string>
 #include <unordered_map>
 #include "paddle/fluid/framework/op_registry.h"
@@ -46,6 +47,9 @@ class AdamOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override;
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const framework::Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override;
 };

 struct GPUAdam;
@@ -54,43 +58,6 @@ struct CPUAdam;
 template <typename T, typename Flag>
 class AdamFunctor;

-template <typename T>
-class BetaPowFunctor {
- private:
-  T beta1_;
-  T beta2_;
-  const T* beta1_pow_;
-  const T* beta2_pow_;
-  T* beta1_pow_out_;
-  T* beta2_pow_out_;
-
- public:
-  BetaPowFunctor(T beta1, T beta2, const T* beta1_pow, const T* beta2_pow,
-                 T* beta1_pow_out, T* beta2_pow_out)
-      : beta1_(beta1),
-        beta2_(beta2),
-        beta1_pow_(beta1_pow),
-        beta2_pow_(beta2_pow),
-        beta1_pow_out_(beta1_pow_out),
-        beta2_pow_out_(beta2_pow_out) {}
-
-  inline HOSTDEVICE void update_step(size_t i) const {
-    T beta1_pow_i = beta1_pow_[i];
-    T beta2_pow_i = beta2_pow_[i];
-
-    beta1_pow_out_[i] = beta1_pow_i * beta1_;
-    beta2_pow_out_[i] = beta2_pow_i * beta2_;
-  }
-
-  inline HOSTDEVICE void operator()(size_t i) const { update_step(i); }
-
-  inline HOSTDEVICE void apply_update(size_t limit) const {
-    for (size_t i = 0; i < limit; ++i) {
-      update_step(i);
-    }
-  }
-};
-
 template <typename T>
 class AdamFunctor<T, GPUAdam> {
  private:
@@ -423,29 +390,20 @@ class AdamOpKernel : public framework::OpKernel<T> {
         ctx.Attr<int64_t>("min_row_size_to_use_multithread");
     bool lazy_mode = ctx.Attr<bool>("lazy_mode");
     T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
-    auto& param = Ref(ctx.Input<LoDTensor>("Param"), "Must set Param");
-    // auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
+    auto* param = ctx.Input<LoDTensor>("Param");
     auto* grad_var = ctx.InputVar("Grad");
-    auto& mom1 = Ref(ctx.Input<LoDTensor>("Moment1"), "Must set Moment1");
-    auto& mom2 = Ref(ctx.Input<LoDTensor>("Moment2"), "Must set Moment2");
-    auto& lr =
-        Ref(ctx.Input<LoDTensor>("LearningRate"), "Must set LearningRate");
-
-    auto& beta1_pow =
-        Ref(ctx.Input<LoDTensor>("Beta1Pow"), "Must set Beta1Pow");
-    auto& beta2_pow =
-        Ref(ctx.Input<LoDTensor>("Beta2Pow"), "Must set Beta2Pow");
-
-    auto& param_out =
-        Ref(ctx.Output<LoDTensor>("ParamOut"), "Must set ParamOut");
-    auto& mom1_out =
-        Ref(ctx.Output<LoDTensor>("Moment1Out"), "Must set Moment1Out");
-    auto& mom2_out =
-        Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment1Out");
-    auto& beta1_pow_out =
-        Ref(ctx.Output<LoDTensor>("Beta1PowOut"), "Must set Beta1PowOut");
-    auto& beta2_pow_out =
-        Ref(ctx.Output<LoDTensor>("Beta2PowOut"), "Must set Beta2PowOut");
+    auto* mom1 = ctx.Input<LoDTensor>("Moment1");
+    auto* mom2 = ctx.Input<LoDTensor>("Moment2");
+    auto* lr = ctx.Input<LoDTensor>("LearningRate");
+
+    auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
+    auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");
+
+    auto* param_out = ctx.Output<LoDTensor>("ParamOut");
+    auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
+    auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
+    auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
+    auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");

     T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
     if (ctx.HasInput("Beta1Tensor")) {
@@ -457,60 +415,45 @@ class AdamOpKernel : public framework::OpKernel<T> {
       auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
       beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
     }
-    VLOG(3) << "beta1_pow.numel() : " << beta1_pow.numel()
-            << "beta2_pow.numel() : " << beta2_pow.numel();
-    VLOG(3) << "param.numel(): " << param.numel();
-    BetaPowFunctor<T> beta_functor(
-        beta1, beta2, beta1_pow.template data<T>(),
-        beta2_pow.template data<T>(),
-        beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
-        beta2_pow_out.template mutable_data<T>(ctx.GetPlace()));
+    VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
+            << "beta2_pow.numel() : " << beta2_pow->numel();
+    VLOG(3) << "param.numel(): " << param->numel();
+
+    PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
+                      platform::errors::InvalidArgument(
+                          "beta1 pow output size should be 1, but received "
+                          "value is:%d.",
+                          beta1_pow_out->numel()));
+
+    PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
+                      platform::errors::InvalidArgument(
+                          "beta2 pow output size should be 1, but received "
+                          "value is:%d.",
+                          beta2_pow_out->numel()));

     if (grad_var->IsType<framework::LoDTensor>()) {
-      auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
-
-      if (platform::is_cpu_place(ctx.GetPlace())) {
-        AdamFunctor<T, CPUAdam> functor(
-            beta1, beta2, epsilon, beta1_pow.template data<T>(),
-            beta2_pow.template data<T>(), mom1.template data<T>(),
-            mom1_out.template mutable_data<T>(ctx.GetPlace()),
-            mom2.template data<T>(),
-            mom2_out.template mutable_data<T>(ctx.GetPlace()),
-            lr.template data<T>(), grad.template data<T>(),
-            param.template data<T>(),
-            param_out.template mutable_data<T>(ctx.GetPlace()));
-        functor(param.numel());
-        beta_functor.apply_update(beta2_pow.numel());
-      } else if (platform::is_gpu_place(ctx.GetPlace())) {
-        AdamFunctor<T, GPUAdam> functor(
-            beta1, beta2, epsilon, beta1_pow.template data<T>(),
-            beta2_pow.template data<T>(), mom1.template data<T>(),
-            mom1_out.template mutable_data<T>(ctx.GetPlace()),
-            mom2.template data<T>(),
-            mom2_out.template mutable_data<T>(ctx.GetPlace()),
-            lr.template data<T>(), grad.template data<T>(),
-            param.template data<T>(),
-            param_out.template mutable_data<T>(ctx.GetPlace()));
-        // update param and moment
-        platform::ForRange<DeviceContext> for_range(
-            static_cast<const DeviceContext&>(ctx.device_context()),
-            param.numel());
-        for_range(functor);
-        // update beta1 and beta2
-        platform::ForRange<DeviceContext> for_range_beta(
-            static_cast<const DeviceContext&>(ctx.device_context()),
-            beta2_pow.numel());
-        for_range_beta(beta_functor);
-      }
+      auto* grad = ctx.Input<LoDTensor>("Grad");
+
+      AdamFunctor<T, CPUAdam> functor(
+          beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
+          mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
+          mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
+          lr->data<T>(), grad->data<T>(), param->data<T>(),
+          param_out->mutable_data<T>(ctx.GetPlace()));
+      functor(param->numel());
+      beta1_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
+          beta1 * beta1_pow->data<T>()[0];
+      beta2_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
+          beta2 * beta2_pow->data<T>()[0];
+
     } else if (grad_var->IsType<framework::SelectedRows>()) {
-      auto& grad =
-          Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
-      if (grad.rows().size() == 0) {
+      auto* grad = ctx.Input<framework::SelectedRows>("Grad");
+      if (grad->rows().size() == 0) {
         VLOG(3) << "grad row size is 0!!";
         return;
       }

-      std::vector<int64_t> cpu_rows(grad.rows().begin(), grad.rows().end());
+      std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
       bool is_strict_sorted = true;
       for (size_t i = 1; i < cpu_rows.size(); ++i) {
         if (cpu_rows[i - 1] >= cpu_rows[i]) {
@@ -522,12 +465,12 @@ class AdamOpKernel : public framework::OpKernel<T> {
       framework::SelectedRows tmp_grad_merge;
       const framework::SelectedRows* grad_merge_ptr;
       if (is_strict_sorted) {
-        grad_merge_ptr = &grad;
+        grad_merge_ptr = grad;
       } else {
         // merge duplicated rows if any.
         // The rows of grad_merge have been sorted inside MergeAdd functor
         scatter::MergeAdd<DeviceContext, T> merge_func;
-        merge_func(ctx.template device_context<DeviceContext>(), grad,
+        merge_func(ctx.template device_context<DeviceContext>(), *grad,
                    &tmp_grad_merge, true);
         grad_merge_ptr = &tmp_grad_merge;
       }
@@ -538,112 +481,89 @@ class AdamOpKernel : public framework::OpKernel<T> {
       const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

-      if (platform::is_cpu_place(ctx.GetPlace())) {
-        SparseAdamFunctor<T, CPUAdam> functor(
-            beta1, beta2, epsilon, beta1_pow.template data<T>(),
-            beta2_pow.template data<T>(), mom1.template data<T>(),
-            mom1_out.template mutable_data<T>(ctx.GetPlace()),
-            mom2.template data<T>(),
-            mom2_out.template mutable_data<T>(ctx.GetPlace()),
-            lr.template data<T>(), grad_data, param.template data<T>(),
-            param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
-            grad_merge.rows().size(), lazy_mode);
-        // update beta1 and beta2
-        beta_functor.apply_update(beta2_pow.numel());
-        if (lazy_mode) {
-          VLOG(3) << "run cpu lazy mode";
-          size_t row_count = grad_merge.rows().size();
-          std::vector<int64_t> cpu_rows(grad_merge.rows());
-          for (size_t row_index = 0; row_index < row_count; ++row_index) {
-            for (size_t offset = 0; offset < row_numel; ++offset) {
-              size_t i = cpu_rows[row_index] * row_numel + offset;
-              functor.adam_update(i, grad_data[row_index * row_numel + offset]);
-            }
+      SparseAdamFunctor<T, CPUAdam> functor(
+          beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
+          mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
+          mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
+          lr->data<T>(), grad_data, param->data<T>(),
+          param_out->mutable_data<T>(ctx.GetPlace()), rows, row_numel,
+          grad_merge.rows().size(), lazy_mode);
+      // update beta1 and beta2
+      beta1_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
+          beta1 * beta1_pow->data<T>()[0];
+      beta2_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
+          beta2 * beta2_pow->data<T>()[0];
+      if (lazy_mode) {
+        VLOG(3) << "run cpu lazy mode";
+        size_t row_count = grad_merge.rows().size();
+        std::vector<int64_t> cpu_rows(grad_merge.rows());
+        for (size_t row_index = 0; row_index < row_count; ++row_index) {
+          for (size_t offset = 0; offset < row_numel; ++offset) {
+            size_t i = cpu_rows[row_index] * row_numel + offset;
+            functor.adam_update(i, grad_data[row_index * row_numel + offset]);
          }
        }
+      }
 #ifndef _WIN32
-        else if (FLAGS_inner_op_parallelism > 1 &&  // NOLINT
-                 min_row_size_to_use_multithread > 0 &&
-                 param.dims()[0] > min_row_size_to_use_multithread) {
-          VLOG(3) << "use multi thread, inner_op_parallelism="
-                  << FLAGS_inner_op_parallelism
-                  << " min_row_size_to_use_multithread="
-                  << min_row_size_to_use_multithread;
-          if (FLAGS_inner_op_parallelism > 10) {
-            VLOG(1) << "FLAGS_inner_op_parallelism "
-                    << FLAGS_inner_op_parallelism << " is two large!";
-          }
-          auto& grad_rows = grad_merge.rows();
-          std::unordered_map<int64_t, size_t> row_id_to_grad_row_offset;
-          size_t param_row_count = param.numel() / row_numel;
-          if (param_row_count < 1000) {
-            VLOG(1) << "param_row_count should be larger then 1000 to use "
-                       "multi thread, currently "
-                    << param_row_count;
+      else if (FLAGS_inner_op_parallelism > 1 &&  // NOLINT
+               min_row_size_to_use_multithread > 0 &&
+               param->dims()[0] > min_row_size_to_use_multithread) {
+        VLOG(3) << "use multi thread, inner_op_parallelism="
+                << FLAGS_inner_op_parallelism
+                << " min_row_size_to_use_multithread="
+                << min_row_size_to_use_multithread;
+        if (FLAGS_inner_op_parallelism > 10) {
+          VLOG(1) << "FLAGS_inner_op_parallelism " << FLAGS_inner_op_parallelism
+                  << " is too large!";
+        }
+        auto& grad_rows = grad_merge.rows();
+        std::unordered_map<int64_t, size_t> row_id_to_grad_row_offset;
+        size_t param_row_count = param->numel() / row_numel;
+        if (param_row_count < 1000) {
+          VLOG(1) << "param_row_count should be larger than 1000 to use "
+                     "multi thread, currently "
+                  << param_row_count;
+        }
        for (size_t i = 0; i < grad_rows.size(); ++i) {
          row_id_to_grad_row_offset[grad_rows[i]] = i;
        }
        std::vector<std::future<void>> fs;
        int64_t line_in_each_thread =
            param_row_count / FLAGS_inner_op_parallelism + 1;
        for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
          int64_t start = i * line_in_each_thread;
          int64_t end = (i + 1) * line_in_each_thread;
          if (start >= static_cast<int64_t>(param_row_count)) {
            break;
-          for (size_t i = 0; i < grad_rows.size(); ++i) {
-            row_id_to_grad_row_offset[grad_rows[i]] = i;
+          }
+          if (end > static_cast<int64_t>(param_row_count)) {
+            end = static_cast<int64_t>(param_row_count);
          }
-          std::vector<std::future<void>> fs;
-          int64_t line_in_each_thread =
-              param_row_count / FLAGS_inner_op_parallelism + 1;
-          for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
-            int64_t start = i * line_in_each_thread;
-            int64_t end = (i + 1) * line_in_each_thread;
-            if (start >= static_cast<int64_t>(param_row_count)) {
-              break;
+          fs.push_back(framework::Async([&functor, &row_id_to_grad_row_offset,
+                                         &grad_data, row_numel, start, end]() {
+            for (int64_t row_id = start; row_id < end; ++row_id) {
+              auto iter = row_id_to_grad_row_offset.find(row_id);
+              if (iter != row_id_to_grad_row_offset.end()) {
+                for (size_t row_offset = 0U; row_offset < row_numel;
+                     ++row_offset) {
+                  functor.adam_update(
+                      row_id * row_numel + row_offset,
+                      grad_data[iter->second * row_numel + row_offset]);
+                }
+              } else {
+                for (size_t row_offset = 0U; row_offset < row_numel;
+                     ++row_offset) {
+                  functor.adam_update(row_id * row_numel + row_offset, 0);
+                }
+              }
            }
-            if (end > static_cast<int64_t>(param_row_count)) {
-              end = static_cast<int64_t>(param_row_count);
-            }
-            fs.push_back(
-                framework::Async([&functor, &row_id_to_grad_row_offset,
-                                  &grad_data, row_numel, start, end]() {
-                  for (int64_t row_id = start; row_id < end; ++row_id) {
-                    auto iter = row_id_to_grad_row_offset.find(row_id);
-                    if (iter != row_id_to_grad_row_offset.end()) {
-                      for (size_t row_offset = 0U; row_offset < row_numel;
-                           ++row_offset) {
-                        functor.adam_update(
-                            row_id * row_numel + row_offset,
-                            grad_data[iter->second * row_numel + row_offset]);
-                      }
-                    } else {
-                      for (size_t row_offset = 0U; row_offset < row_numel;
-                           ++row_offset) {
-                        functor.adam_update(row_id * row_numel + row_offset, 0);
-                      }
-                    }
-                  }
-                }));
-          }
-          for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
+          }));
        }
-#endif  // !_WIN32
-        else {  // NOLINT
-          functor(param.numel());
-        }
-      } else if (platform::is_gpu_place(ctx.GetPlace())) {
-        SparseAdamFunctor<T, GPUAdam> functor(
-            beta1, beta2, epsilon, beta1_pow.template data<T>(),
-            beta2_pow.template data<T>(), mom1.template data<T>(),
-            mom1_out.template mutable_data<T>(ctx.GetPlace()),
-            mom2.template data<T>(),
-            mom2_out.template mutable_data<T>(ctx.GetPlace()),
-            lr.template data<T>(), grad_data, param.template data<T>(),
-            param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
-            grad_merge.rows().size(), lazy_mode);
-
-        // FIXME(minqiyang): remove BinarySearch in GPU later
-        platform::ForRange<DeviceContext> for_range(
-            static_cast<const DeviceContext&>(ctx.device_context()),
-            param.numel());
-        for_range(functor);
-        // update beta1 and beta2
-        platform::ForRange<DeviceContext> for_range_beta(
-            static_cast<const DeviceContext&>(ctx.device_context()),
-            beta2_pow.numel());
-        for_range_beta(beta_functor);
+        for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
+      }
+#endif  // !_WIN32
+      else {  // NOLINT
+        functor(param->numel());
+      }
     } else {
       PADDLE_THROW("Variable type not supported by adam_op");
diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h
index 9e5dc638516..d1d7681b7ba 100644
--- a/paddle/fluid/pybind/tensor_py.h
+++ b/paddle/fluid/pybind/tensor_py.h
@@ -187,12 +187,21 @@ void SetTensorFromPyArrayT(
     }
   } else {
 #ifdef PADDLE_WITH_CUDA
-    auto dst = self->mutable_data<T>(place);
+    T *dst;
+    if (array.nbytes() <= 4 && !paddle::platform::is_cuda_pinned_place(place)) {
+      dst = self->mutable_data<T>(platform::CPUPlace());
+    } else {
+      dst = self->mutable_data<T>(place);
+    }
     if (paddle::platform::is_cuda_pinned_place(place)) {
       std::memcpy(dst, array.data(), array.nbytes());
     } else if (paddle::platform::is_gpu_place(place)) {
-      paddle::platform::GpuMemcpySync(dst, array.data(), array.nbytes(),
-                                      cudaMemcpyHostToDevice);
+      if (array.nbytes() <= 4) {
+        std::memcpy(dst, array.data(), array.nbytes());
+      } else {
+        paddle::platform::GpuMemcpySync(dst, array.data(), array.nbytes(),
+                                        cudaMemcpyHostToDevice);
+      }
     } else {
       PADDLE_THROW(
           "Incompatible place type: Tensor.set() supports CPUPlace, CUDAPlace "
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index dd7995c6f7f..0f1d593fa17 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -404,7 +404,8 @@ class Optimizer(object):
                          dtype=None,
                          fill_value=0.0,
                          shape=None,
-                         type=None):
+                         type=None,
+                         force_cpu=False):
         """Utility function to add an accumulator for a parameter

         Args:
@@ -438,7 +439,9 @@ class Optimizer(object):
             shape=shape,
             belong_to_optimizer=True)
         self.helper.set_variable_initializer(
-            var, initializer=Constant(value=float(fill_value)))
+            var,
+            initializer=Constant(
+                value=float(fill_value), force_cpu=force_cpu))

         if framework.in_dygraph_mode():
             if len(self._accumulators_holder) > 0:
@@ -1790,14 +1793,14 @@ class AdamOptimizer(Optimizer):
                 fill_value=0.9 if isinstance(self._beta1, Variable) \
                         else self._beta1,
                 shape=[1],
-                type=core.VarDesc.VarType.LOD_TENSOR)
+                type=core.VarDesc.VarType.LOD_TENSOR, force_cpu=True)
             self._add_accumulator(
                 name=self._beta2_pow_acc_str,
                 param=p,
                 fill_value=0.999 if isinstance(self._beta2, Variable) \
                         else self._beta2,
                 shape=[1],
-                type=core.VarDesc.VarType.LOD_TENSOR)
+                type=core.VarDesc.VarType.LOD_TENSOR, force_cpu=True)

     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
-- 
GitLab
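For reference, the update that AdamKernelREG and AdamKernelMEM implement (and that SparseAdamCUDAKernelREG applies per looked-up row) is the standard Adam step with the bias correction folded into the learning rate; this only restates the arithmetic already visible in the kernels above:

\[
\hat{\eta}_t = \eta \,\frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}}, \qquad
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \qquad
\theta_t = \theta_{t-1} - \hat{\eta}_t \,\frac{m_t}{\sqrt{v_t}+\epsilon}
\]

Here beta1_pow and beta2_pow hold the running powers beta1^t and beta2^t. The REG variant receives them by value because they now live on CPUPlace (force_cpu=True on the Python side, GetKernelTypeForVar keeping them on their own place on the C++ side), while the MEM variant dereferences them from device memory; the pow tensors are then advanced either on the host or by the small UpdateBetaPow kernel.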
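A minimal sketch of exercising the Python-side change (a hypothetical script assuming the fluid 1.7-era API; the network, variable names, and sizes are illustrative only, not part of the patch). AdamOptimizer now passes force_cpu=True when creating its Beta1Pow/Beta2Pow accumulators, so on a CUDA place the fused AdamKernelREG path is taken without any user-facing change:

    import numpy as np
    import paddle.fluid as fluid

    # Tiny regression network; names and sizes are made up for illustration.
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

    # Nothing Adam-specific to configure: the beta-pow accumulators are now
    # created on CPUPlace internally (force_cpu=True in _add_accumulator).
    fluid.optimizer.AdamOptimizer(learning_rate=1e-3).minimize(loss)

    place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda() \
        else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    loss_val, = exe.run(fluid.default_main_program(),
                        feed={'x': np.random.rand(8, 13).astype('float32'),
                              'y': np.random.rand(8, 1).astype('float32')},
                        fetch_list=[loss])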