Unverified commit 72dde4ab, authored by zhaoyuchen2018, committed by GitHub

Refine adam op to improve performance, test=develop (#22346)

* Refine adam op, test=develop

* Fuse kernels together to reduce cpu time.

* Refine paddle enforce, test=develop

* Remove some comments, test=develop

* Refine code,test=develop

* Refine cuda kernel, test=develop

* Refine code according to comments, test=develop
Parent 8c381cd9
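For orientation, the per-element update that the fused CUDA kernels added below (AdamKernelREG / AdamKernelMEM) implement is, as I read the kernel bodies, the standard Adam rule with the bias correction folded into the learning rate; UpdateBetaPow (or the CPU-side scalar writes) advances the last two accumulators once per step:

\[
\begin{aligned}
\hat{lr} &= lr \cdot \sqrt{1-\beta_2^{t}}\,/\,(1-\beta_1^{t}) \\
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g \\
v_t &= \beta_2\, v_{t-1} + (1-\beta_2)\, g^2 \\
p_t &= p_{t-1} - \hat{lr}\,\frac{m_t}{\sqrt{v_t}+\epsilon} \\
\beta_1^{t} &= \beta_1 \cdot \beta_1^{t-1}, \qquad \beta_2^{t} = \beta_2 \cdot \beta_2^{t-1}
\end{aligned}
\]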
@@ -145,7 +145,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
auto size = lod_tensors[i]->numel();
PADDLE_ENFORCE_GT(size, 0);
ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims()
<< "), ";
<< ") "
<< " addres:" << lod_tensors[i]->data<void>() << ", ";
*numel += platform::Alignment(static_cast<size_t>(size) * size_of_dtype,
place) /
size_of_dtype;
@@ -160,6 +161,15 @@ class CoalesceTensorOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
expected_kernel_type.place_,
tensor.layout());
}
};
class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker {
......
@@ -19,7 +19,7 @@ namespace operators {
using Tensor = framework::Tensor;
void AdamOp::InferShape(framework::InferShapeContext* ctx) const {
void AdamOp::InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Param"), true,
platform::errors::NotFound("Input(Param) of AdamOp should not be null."));
@@ -126,11 +126,22 @@ void AdamOp::InferShape(framework::InferShapeContext* ctx) const {
}
framework::OpKernelType AdamOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
const framework::ExecutionContext &ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
framework::OpKernelType AdamOp::GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
if (var_name == "Beta1Pow" || var_name == "Beta2Pow") {
return expected_kernel_type;
} else {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}
class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......
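The GetKernelTypeForVar override above is what lets the Beta1Pow/Beta2Pow accumulators live on the CPU while the rest of the op runs on the GPU: as I read it, returning the expected kernel type for those inputs skips the usual place-based data transform, so no host-to-device copy is inserted for them. A minimal sketch of the same pattern for a hypothetical op that keeps a small auxiliary input named StepCounter wherever the user placed it (the input name is illustrative, not part of this PR):

framework::OpKernelType GetKernelTypeForVar(
    const std::string &var_name, const framework::Tensor &tensor,
    const framework::OpKernelType &expected_kernel_type) const override {
  if (var_name == "StepCounter") {
    // Report the expected type unchanged so the framework does not insert a
    // place transform (host<->device copy) for this input.
    return expected_kernel_type;
  }
  // All other inputs keep their current place and layout.
  return framework::OpKernelType(expected_kernel_type.data_type_,
                                 tensor.place(), tensor.layout());
}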
@@ -13,7 +13,286 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/adam_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void AdamKernelREG(T beta1, T beta2, T epsilon, T beta1_pow_,
T beta2_pow_, const T* moment1, T* moment1_out,
const T* moment2, T* moment2_out, const T* lr_,
const T* grad, const T* param, T* param_out,
int ndim) {
T lr = *lr_;
T beta1_pow = beta1_pow_;
T beta2_pow = beta2_pow_;
lr *=
sqrt(static_cast<T>(1.0) - beta2_pow) / (static_cast<T>(1.0) - beta1_pow);
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
T p = param[id];
T g = grad[id];
T mom1 = moment1[id];
T mom2 = moment2[id];
mom1 = beta1 * mom1 + (static_cast<T>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<T>(1.0) - beta2) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
param_out[id] = p;
}
}
template <typename T>
__global__ void AdamKernelMEM(T beta1, T beta2, T epsilon, const T* beta1_pow_,
const T* beta2_pow_, const T* moment1,
T* moment1_out, const T* moment2, T* moment2_out,
const T* lr_, const T* grad, const T* param,
T* param_out, int ndim) {
T lr = *lr_;
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
lr *=
sqrt(static_cast<T>(1.0) - beta2_pow) / (static_cast<T>(1.0) - beta1_pow);
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
T p = param[id];
T g = grad[id];
T mom1 = moment1[id];
T mom2 = moment2[id];
mom1 = beta1 * mom1 + (static_cast<T>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<T>(1.0) - beta2) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
param_out[id] = p;
}
}
template <typename T>
__global__ void UpdateBetaPow(T beta1, T beta2, const T* beta1_pow_,
const T* beta2_pow_, T* beta1_pow_out,
T* beta2_pow_out) {
*beta1_pow_out = beta1 * beta1_pow_[0];
*beta2_pow_out = beta2 * beta2_pow_[0];
}
template <typename T>
__global__ void SparseAdamCUDAKernelREG(
T beta1, T beta2, T epsilon, const T beta1_pow, const T beta2_pow,
const T* mom1_, T* mom1_out_, const T* mom2_, T* mom2_out_, const T* lr_,
const T* grad_, const T* param_, T* param_out_, const int64_t* rows_,
int64_t row_numel, int64_t row_count, bool lazy_mode, int ndim) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
T lr = *lr_;
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
for (; id < ndim; id += blockDim.x * gridDim.x) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_count, id / row_numel);
if (lazy_mode && row_idx < 0) {
return;
} else {
T mom1 = mom1_[id];
T mom2 = mom2_[id];
T p = param_[id];
T g = row_idx >= 0 ? grad_[row_idx * row_numel + id % row_numel] : 0;
mom1 = beta1 * mom1 + (1 - beta1) * g;
mom2 = beta2 * mom2 + (1 - beta2) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon));
// Write back to global memory
mom1_out_[id] = mom1;
mom2_out_[id] = mom2;
param_out_[id] = p;
}
}
}
template <typename T>
class AdamOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
using paddle::framework::LoDTensor;
using paddle::operators::detail::Ref;
int64_t min_row_size_to_use_multithread =
ctx.Attr<int64_t>("min_row_size_to_use_multithread");
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
auto* mom1 = ctx.Input<LoDTensor>("Moment1");
auto* mom2 = ctx.Input<LoDTensor>("Moment2");
auto* lr = ctx.Input<LoDTensor>("LearningRate");
auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");
auto* param_out = ctx.Output<LoDTensor>("ParamOut");
auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) {
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
}
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
if (ctx.HasInput("Beta2Tensor")) {
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
<< "beta2_pow.numel() : " << beta2_pow->numel();
VLOG(3) << "param.numel(): " << param->numel();
PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"beta1 pow output size should be 1, but received "
"value is:%d.",
beta1_pow_out->numel()));
PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"beta2 pow output size should be 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if (grad_var->IsType<framework::LoDTensor>()) {
auto* grad = ctx.Input<LoDTensor>("Grad");
// update param and moment
int threads = 512;
int blocks = (param->numel() + threads - 1) / threads;
if (beta1_pow->place() == platform::CPUPlace() &&
beta2_pow->place() == platform::CPUPlace()) {
// Compute with betapow in REG
AdamKernelREG<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, *beta1_pow->data<T>(), *beta2_pow->data<T>(),
mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
lr->data<T>(), grad->data<T>(), param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), param->numel());
// Cpu update
beta1_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
beta1 * beta1_pow->data<T>()[0];
beta2_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
beta2 * beta2_pow->data<T>()[0];
} else {
AdamKernelMEM<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
lr->data<T>(), grad->data<T>(), param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), param->numel());
// Update with gpu
UpdateBetaPow<T><<<1, 32, 0, dev_ctx.stream()>>>(
beta1, beta2, beta1_pow->data<T>(), beta2_pow->data<T>(),
beta1_pow_out->mutable_data<T>(ctx.GetPlace()),
beta2_pow_out->mutable_data<T>(ctx.GetPlace()));
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto* grad = ctx.Input<framework::SelectedRows>("Grad");
if (grad->rows().size() == 0) {
VLOG(3) << "grad row size is 0!!";
return;
}
std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
bool is_strict_sorted = true;
for (size_t i = 1; i < cpu_rows.size(); ++i) {
if (cpu_rows[i - 1] >= cpu_rows[i]) {
is_strict_sorted = false;
break;
}
}
framework::SelectedRows tmp_grad_merge;
const framework::SelectedRows* grad_merge_ptr;
if (is_strict_sorted) {
grad_merge_ptr = grad;
} else {
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
merge_func(ctx.template device_context<platform::CUDADeviceContext>(),
*grad, &tmp_grad_merge, true);
grad_merge_ptr = &tmp_grad_merge;
}
auto& grad_merge = *grad_merge_ptr;
auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>();
const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
if (beta1_pow->place() == platform::CPUPlace() &&
beta2_pow->place() == platform::CPUPlace()) {
int threads = 512;
int ndim = param->numel();
int blocks = (ndim + threads - 1) / threads;
SparseAdamCUDAKernelREG<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, *beta1_pow->data<T>(), *beta2_pow->data<T>(),
mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
lr->data<T>(), grad_data, param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode, ndim);
// Update with cpu
beta1_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
beta1 * beta1_pow->data<T>()[0];
beta2_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
beta2 * beta2_pow->data<T>()[0];
} else {
SparseAdamFunctor<T, GPUAdam> functor(
beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
lr->data<T>(), grad_data, param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode);
// FIXME(minqiyang): remove BinarySearch in GPU later
platform::ForRange<platform::CUDADeviceContext> for_range(
static_cast<const platform::CUDADeviceContext&>(
ctx.device_context()),
param->numel());
for_range(functor);
// update beta1 and beta2
UpdateBetaPow<T><<<1, 32, 0, dev_ctx.stream()>>>(
beta1, beta2, beta1_pow->data<T>(), beta2_pow->data<T>(),
beta1_pow_out->mutable_data<T>(ctx.GetPlace()),
beta2_pow_out->mutable_data<T>(ctx.GetPlace()));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Variable type not supported by adam_op"));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
adam, ops::AdamOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::AdamOpKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(adam, ops::AdamOpCUDAKernel<float>,
ops::AdamOpCUDAKernel<double>);
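All of the kernels above iterate with a grid-stride loop, so the fixed threads = 512 launch covers parameters of any length. A self-contained sketch of the pattern with the same launch convention (ScaleKernel is a hypothetical example, not part of the PR):

template <typename T>
__global__ void ScaleKernel(const T *x, T *y, T a, int n) {
  // Each thread starts at its global index and advances by the total number
  // of launched threads, so any grid size covers all n elements.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    y[i] = a * x[i];
  }
}
// Launch, mirroring AdamOpCUDAKernel:
//   int threads = 512;
//   int blocks = (n + threads - 1) / threads;
//   ScaleKernel<float><<<blocks, threads, 0, stream>>>(x, y, 2.0f, n);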
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <math.h> // for sqrt in CPU and CUDA
#include <Eigen/Dense>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
@@ -46,6 +47,9 @@ class AdamOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
};
struct GPUAdam;
@@ -54,43 +58,6 @@ struct CPUAdam;
template <typename T, typename Flavour>
class AdamFunctor;
template <typename T>
class BetaPowFunctor {
private:
T beta1_;
T beta2_;
const T* beta1_pow_;
const T* beta2_pow_;
T* beta1_pow_out_;
T* beta2_pow_out_;
public:
BetaPowFunctor(T beta1, T beta2, const T* beta1_pow, const T* beta2_pow,
T* beta1_pow_out, T* beta2_pow_out)
: beta1_(beta1),
beta2_(beta2),
beta1_pow_(beta1_pow),
beta2_pow_(beta2_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_out_(beta2_pow_out) {}
inline HOSTDEVICE void update_step(size_t i) const {
T beta1_pow_i = beta1_pow_[i];
T beta2_pow_i = beta2_pow_[i];
beta1_pow_out_[i] = beta1_pow_i * beta1_;
beta2_pow_out_[i] = beta2_pow_i * beta2_;
}
inline HOSTDEVICE void operator()(size_t i) const { update_step(i); }
inline HOSTDEVICE void apply_update(size_t limit) const {
for (size_t i = 0; i < limit; ++i) {
update_step(i);
}
}
};
template <typename T>
class AdamFunctor<T, GPUAdam> {
private:
@@ -423,29 +390,20 @@ class AdamOpKernel : public framework::OpKernel<T> {
ctx.Attr<int64_t>("min_row_size_to_use_multithread");
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto& param = Ref(ctx.Input<LoDTensor>("Param"), "Must set Param");
// auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
auto& mom1 = Ref(ctx.Input<LoDTensor>("Moment1"), "Must set Moment1");
auto& mom2 = Ref(ctx.Input<LoDTensor>("Moment2"), "Must set Moment2");
auto& lr =
Ref(ctx.Input<LoDTensor>("LearningRate"), "Must set LearningRate");
auto& beta1_pow =
Ref(ctx.Input<LoDTensor>("Beta1Pow"), "Must set Beta1Pow");
auto& beta2_pow =
Ref(ctx.Input<LoDTensor>("Beta2Pow"), "Must set Beta2Pow");
auto& param_out =
Ref(ctx.Output<LoDTensor>("ParamOut"), "Must set ParamOut");
auto& mom1_out =
Ref(ctx.Output<LoDTensor>("Moment1Out"), "Must set Moment1Out");
auto& mom2_out =
Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment1Out");
auto& beta1_pow_out =
Ref(ctx.Output<LoDTensor>("Beta1PowOut"), "Must set Beta1PowOut");
auto& beta2_pow_out =
Ref(ctx.Output<LoDTensor>("Beta2PowOut"), "Must set Beta2PowOut");
auto* mom1 = ctx.Input<LoDTensor>("Moment1");
auto* mom2 = ctx.Input<LoDTensor>("Moment2");
auto* lr = ctx.Input<LoDTensor>("LearningRate");
auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");
auto* param_out = ctx.Output<LoDTensor>("ParamOut");
auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) {
@@ -457,60 +415,45 @@ class AdamOpKernel : public framework::OpKernel<T> {
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pow.numel()
<< "beta2_pow.numel() : " << beta2_pow.numel();
VLOG(3) << "param.numel(): " << param.numel();
BetaPowFunctor<T> beta_functor(
beta1, beta2, beta1_pow.template data<T>(),
beta2_pow.template data<T>(),
beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
beta2_pow_out.template mutable_data<T>(ctx.GetPlace()));
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
<< "beta2_pow.numel() : " << beta2_pow->numel();
VLOG(3) << "param.numel(): " << param->numel();
PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"beta1 pow output size should be 1, but received "
"value is:%d.",
beta1_pow_out->numel()));
PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"beta2 pow output size should be 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
if (platform::is_cpu_place(ctx.GetPlace())) {
AdamFunctor<T, CPUAdam> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad.template data<T>(),
param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()));
functor(param.numel());
beta_functor.apply_update(beta2_pow.numel());
} else if (platform::is_gpu_place(ctx.GetPlace())) {
AdamFunctor<T, GPUAdam> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad.template data<T>(),
param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()));
// update param and moment
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param.numel());
for_range(functor);
// update beta1 and beta2
platform::ForRange<DeviceContext> for_range_beta(
static_cast<const DeviceContext&>(ctx.device_context()),
beta2_pow.numel());
for_range_beta(beta_functor);
}
auto* grad = ctx.Input<LoDTensor>("Grad");
AdamFunctor<T, CPUAdam> functor(
beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
lr->data<T>(), grad->data<T>(), param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()));
functor(param->numel());
beta1_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
beta1 * beta1_pow->data<T>()[0];
beta2_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
beta2 * beta2_pow->data<T>()[0];
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto& grad =
Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
if (grad.rows().size() == 0) {
auto* grad = ctx.Input<framework::SelectedRows>("Grad");
if (grad->rows().size() == 0) {
VLOG(3) << "grad row size is 0!!";
return;
}
std::vector<int64_t> cpu_rows(grad.rows().begin(), grad.rows().end());
std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
bool is_strict_sorted = true;
for (size_t i = 1; i < cpu_rows.size(); ++i) {
if (cpu_rows[i - 1] >= cpu_rows[i]) {
@@ -522,12 +465,12 @@ class AdamOpKernel : public framework::OpKernel<T> {
framework::SelectedRows tmp_grad_merge;
const framework::SelectedRows* grad_merge_ptr;
if (is_strict_sorted) {
grad_merge_ptr = &grad;
grad_merge_ptr = grad;
} else {
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(ctx.template device_context<DeviceContext>(), grad,
merge_func(ctx.template device_context<DeviceContext>(), *grad,
&tmp_grad_merge, true);
grad_merge_ptr = &tmp_grad_merge;
}
@@ -538,112 +481,89 @@ class AdamOpKernel : public framework::OpKernel<T> {
const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
if (platform::is_cpu_place(ctx.GetPlace())) {
SparseAdamFunctor<T, CPUAdam> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode);
// update beta1 and beta2
beta_functor.apply_update(beta2_pow.numel());
if (lazy_mode) {
VLOG(3) << "run cpu lazy mode";
size_t row_count = grad_merge.rows().size();
std::vector<int64_t> cpu_rows(grad_merge.rows());
for (size_t row_index = 0; row_index < row_count; ++row_index) {
for (size_t offset = 0; offset < row_numel; ++offset) {
size_t i = cpu_rows[row_index] * row_numel + offset;
functor.adam_update(i, grad_data[row_index * row_numel + offset]);
}
SparseAdamFunctor<T, CPUAdam> functor(
beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
lr->data<T>(), grad_data, param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode);
// update beta1 and beta2
beta1_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
beta1 * beta1_pow->data<T>()[0];
beta2_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
beta2 * beta2_pow->data<T>()[0];
if (lazy_mode) {
VLOG(3) << "run cpu lazy mode";
size_t row_count = grad_merge.rows().size();
std::vector<int64_t> cpu_rows(grad_merge.rows());
for (size_t row_index = 0; row_index < row_count; ++row_index) {
for (size_t offset = 0; offset < row_numel; ++offset) {
size_t i = cpu_rows[row_index] * row_numel + offset;
functor.adam_update(i, grad_data[row_index * row_numel + offset]);
}
}
}
#ifndef _WIN32
else if (FLAGS_inner_op_parallelism > 1 && // NOLINT
min_row_size_to_use_multithread > 0 &&
param.dims()[0] > min_row_size_to_use_multithread) {
VLOG(3) << "use multi thread, inner_op_parallelism="
<< FLAGS_inner_op_parallelism
<< " min_row_size_to_use_multithread="
<< min_row_size_to_use_multithread;
if (FLAGS_inner_op_parallelism > 10) {
VLOG(1) << "FLAGS_inner_op_parallelism "
<< FLAGS_inner_op_parallelism << " is two large!";
}
auto& grad_rows = grad_merge.rows();
std::unordered_map<size_t, int> row_id_to_grad_row_offset;
size_t param_row_count = param.numel() / row_numel;
if (param_row_count < 1000) {
VLOG(1) << "param_row_count should be larger then 1000 to use "
"multi thread, currently "
<< param_row_count;
else if (FLAGS_inner_op_parallelism > 1 && // NOLINT
min_row_size_to_use_multithread > 0 &&
param->dims()[0] > min_row_size_to_use_multithread) {
VLOG(3) << "use multi thread, inner_op_parallelism="
<< FLAGS_inner_op_parallelism
<< " min_row_size_to_use_multithread="
<< min_row_size_to_use_multithread;
if (FLAGS_inner_op_parallelism > 10) {
VLOG(1) << "FLAGS_inner_op_parallelism " << FLAGS_inner_op_parallelism
<< " is two large!";
}
auto& grad_rows = grad_merge.rows();
std::unordered_map<size_t, int> row_id_to_grad_row_offset;
size_t param_row_count = param->numel() / row_numel;
if (param_row_count < 1000) {
VLOG(1) << "param_row_count should be larger then 1000 to use "
"multi thread, currently "
<< param_row_count;
}
for (size_t i = 0; i < grad_rows.size(); ++i) {
row_id_to_grad_row_offset[grad_rows[i]] = i;
}
std::vector<std::future<void>> fs;
int64_t line_in_each_thread =
param_row_count / FLAGS_inner_op_parallelism + 1;
for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
int64_t start = i * line_in_each_thread;
int64_t end = (i + 1) * line_in_each_thread;
if (start >= static_cast<int64_t>(param_row_count)) {
break;
}
for (size_t i = 0; i < grad_rows.size(); ++i) {
row_id_to_grad_row_offset[grad_rows[i]] = i;
if (end > static_cast<int64_t>(param_row_count)) {
end = static_cast<int64_t>(param_row_count);
}
std::vector<std::future<void>> fs;
int64_t line_in_each_thread =
param_row_count / FLAGS_inner_op_parallelism + 1;
for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
int64_t start = i * line_in_each_thread;
int64_t end = (i + 1) * line_in_each_thread;
if (start >= static_cast<int64_t>(param_row_count)) {
break;
fs.push_back(framework::Async([&functor, &row_id_to_grad_row_offset,
&grad_data, row_numel, start, end]() {
for (int64_t row_id = start; row_id < end; ++row_id) {
auto iter = row_id_to_grad_row_offset.find(row_id);
if (iter != row_id_to_grad_row_offset.end()) {
for (size_t row_offset = 0U; row_offset < row_numel;
++row_offset) {
functor.adam_update(
row_id * row_numel + row_offset,
grad_data[iter->second * row_numel + row_offset]);
}
} else {
for (size_t row_offset = 0U; row_offset < row_numel;
++row_offset) {
functor.adam_update(row_id * row_numel + row_offset, 0);
}
}
}
if (end > static_cast<int64_t>(param_row_count)) {
end = static_cast<int64_t>(param_row_count);
}
fs.push_back(
framework::Async([&functor, &row_id_to_grad_row_offset,
&grad_data, row_numel, start, end]() {
for (int64_t row_id = start; row_id < end; ++row_id) {
auto iter = row_id_to_grad_row_offset.find(row_id);
if (iter != row_id_to_grad_row_offset.end()) {
for (size_t row_offset = 0U; row_offset < row_numel;
++row_offset) {
functor.adam_update(
row_id * row_numel + row_offset,
grad_data[iter->second * row_numel + row_offset]);
}
} else {
for (size_t row_offset = 0U; row_offset < row_numel;
++row_offset) {
functor.adam_update(row_id * row_numel + row_offset, 0);
}
}
}
}));
}
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}));
}
#endif // !_WIN32
else { // NOLINT
functor(param.numel());
}
} else if (platform::is_gpu_place(ctx.GetPlace())) {
SparseAdamFunctor<T, GPUAdam> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode);
// FIXME(minqiyang): remove BinarySearch in GPU later
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param.numel());
for_range(functor);
// update beta1 and beta2
platform::ForRange<DeviceContext> for_range_beta(
static_cast<const DeviceContext&>(ctx.device_context()),
beta2_pow.numel());
for_range_beta(beta_functor);
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}
#endif // !_WIN32
else { // NOLINT
functor(param->numel());
}
} else {
PADDLE_THROW("Variable type not supported by adam_op");
......
@@ -187,12 +187,21 @@ void SetTensorFromPyArrayT(
}
} else {
#ifdef PADDLE_WITH_CUDA
auto dst = self->mutable_data<T>(place);
T *dst;
if (array.nbytes() <= 4 && !paddle::platform::is_cuda_pinned_place(place)) {
dst = self->mutable_data<T>(platform::CPUPlace());
} else {
dst = self->mutable_data<T>(place);
}
if (paddle::platform::is_cuda_pinned_place(place)) {
std::memcpy(dst, array.data(), array.nbytes());
} else if (paddle::platform::is_gpu_place(place)) {
paddle::platform::GpuMemcpySync(dst, array.data(), array.nbytes(),
cudaMemcpyHostToDevice);
if (array.nbytes() <= 4) {
std::memcpy(dst, array.data(), array.nbytes());
} else {
paddle::platform::GpuMemcpySync(dst, array.data(), array.nbytes(),
cudaMemcpyHostToDevice);
}
} else {
PADDLE_THROW(
"Incompatible place type: Tensor.set() supports CPUPlace, CUDAPlace "
......
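The tensor_py.h change above routes tiny host arrays (at most 4 bytes, i.e. a single float/int32 scalar such as a one-element accumulator) into CPU memory instead of issuing a synchronous cudaMemcpy to the requested CUDA place. A minimal sketch of that routing decision as a standalone helper (PickScalarPlace is hypothetical, not part of the PR):

// Returns the place a small host buffer should be materialized on.
platform::Place PickScalarPlace(size_t nbytes,
                                const platform::Place &requested) {
  if (nbytes <= 4 && !platform::is_cuda_pinned_place(requested)) {
    // A one-element scalar stays on the host; the Adam op can read it there
    // via its GetKernelTypeForVar override without a device round-trip.
    return platform::CPUPlace();
  }
  return requested;
}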
@@ -404,7 +404,8 @@ class Optimizer(object):
dtype=None,
fill_value=0.0,
shape=None,
type=None):
type=None,
force_cpu=False):
"""Utility function to add an accumulator for a parameter
Args:
@@ -438,7 +439,9 @@ class Optimizer(object):
shape=shape,
belong_to_optimizer=True)
self.helper.set_variable_initializer(
var, initializer=Constant(value=float(fill_value)))
var,
initializer=Constant(
value=float(fill_value), force_cpu=force_cpu))
if framework.in_dygraph_mode():
if len(self._accumulators_holder) > 0:
@@ -1790,14 +1793,14 @@ class AdamOptimizer(Optimizer):
fill_value=0.9 if isinstance(self._beta1, Variable) \
else self._beta1,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR)
type=core.VarDesc.VarType.LOD_TENSOR, force_cpu=True)
self._add_accumulator(
name=self._beta2_pow_acc_str,
param=p,
fill_value=0.999 if isinstance(self._beta2, Variable) \
else self._beta2,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR)
type=core.VarDesc.VarType.LOD_TENSOR, force_cpu=True)
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
......