未验证 提交 0bd7f97b 编写于 作者: 武毅 提交者: GitHub

Merge pull request #7045 from typhoonzero/adam_selectedrows

Adam selectedrows and scatter functors
...@@ -105,48 +105,18 @@ struct SparseAdagradFunctor<platform::CPUDeviceContext, T> { ...@@ -105,48 +105,18 @@ struct SparseAdagradFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& learning_rate, T epsilon, const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param) { framework::Tensor* moment, framework::Tensor* param) {
// 1. g_m.rows = set(g.rows) // 1. g_m.rows = set(g.rows)
auto grad_rows = grad.rows();
std::set<int64_t> row_set(grad_rows.begin(), grad_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto grad_width = grad.value().dims()[1]; auto grad_width = grad.value().dims()[1];
std::unique_ptr<framework::SelectedRows> grad_merge{ math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
new framework::SelectedRows()}; auto grad_merge = merge_func(context, grad);
grad_merge->set_rows(merge_rows); auto& merge_rows = grad_merge.rows();
grad_merge->set_height(grad.height()); auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
grad_merge->mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), grad_width}),
context.GetPlace());
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
constant_functor(context, grad_merge->mutable_value(), 0.0);
auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
auto* grad_data = grad.value().data<T>();
for (size_t i = 0; i < grad_rows.size(); i++) {
size_t grad_merge_i = FindPos(merge_rows, grad_rows[i]);
for (int64_t j = 0; j < grad_width; j++) {
grad_merge_data[grad_merge_i * grad_width + j] +=
grad_data[i * grad_width + j];
}
}
// 2. m += g_m * g_m // 2. m += g_m * g_m
std::unique_ptr<framework::SelectedRows> grad_square{ math::scatter::Mul<platform::CPUDeviceContext, T> sqare_func;
new framework::SelectedRows()}; auto grad_square = sqare_func(context, grad_merge, grad_merge);
grad_square->set_rows(grad_merge->rows());
grad_square->set_height(grad_merge->height());
grad_square->mutable_value()->mutable_data<T>(grad_merge->value().dims(),
context.GetPlace());
auto gs =
framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
gs.device(*context.eigen_device()) = gm * gm;
math::SelectedRowsAddToTensor<platform::CPUDeviceContext, T> functor; math::SelectedRowsAddToTensor<platform::CPUDeviceContext, T> functor;
functor(context, *grad_square, moment); functor(context, grad_square, moment);
// 3. update parameter // 3. update parameter
auto* lr = learning_rate.data<T>(); auto* lr = learning_rate.data<T>();
......
...@@ -78,62 +78,30 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> { ...@@ -78,62 +78,30 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& learning_rate, T epsilon, const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param) { framework::Tensor* moment, framework::Tensor* param) {
// 1. g_m.rows = set(g.rows) // 1. g_m.rows = set(g.rows)
auto grad_rows = grad.rows();
std::set<int64_t> row_set(grad_rows.begin(), grad_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto grad_width = grad.value().dims()[1]; auto grad_width = grad.value().dims()[1];
std::unique_ptr<framework::SelectedRows> grad_merge{ math::scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
new framework::SelectedRows()}; auto grad_merge = merge_func(context, grad);
grad_merge->set_rows(merge_rows); auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
grad_merge->set_height(grad.height()); auto& merge_rows = grad_merge.rows();
grad_merge->mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), grad_width}),
context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
constant_functor(context, grad_merge->mutable_value(), 0.0);
auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
auto* grad_data = grad.value().data<T>();
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid1(1, grad_rows.size());
MergeGradKernel<
T, 256><<<grid1, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(grad_data, grad.rows().data(),
grad_merge_data, grad_merge->rows().data(),
grad_merge->rows().size(), grad_width);
// 2. m += g_m * g_m // 2. m += g_m * g_m
std::unique_ptr<framework::SelectedRows> grad_square{ math::scatter::Mul<platform::CUDADeviceContext, T> sqare_func;
new framework::SelectedRows()}; auto grad_square = sqare_func(context, grad_merge, grad_merge);
grad_square->set_rows(grad_merge->rows());
grad_square->set_height(grad_merge->height());
grad_square->mutable_value()->mutable_data<T>(grad_merge->value().dims(),
context.GetPlace());
auto gs =
framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
gs.device(*context.eigen_device()) = gm * gm;
math::SelectedRowsAddToTensor<platform::CUDADeviceContext, T> functor; math::SelectedRowsAddToTensor<platform::CUDADeviceContext, T> functor;
functor(context, *grad_square, moment); functor(context, grad_square, moment);
// 3. update parameter // 3. update parameter
auto* lr = learning_rate.data<T>(); auto* lr = learning_rate.data<T>();
auto* param_data = param->data<T>(); auto* param_data = param->data<T>();
auto* moment_data = moment->data<T>(); auto* moment_data = moment->data<T>();
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid2(1, merge_rows.size()); dim3 grid2(1, merge_rows.size());
SparseAdagradFunctorKernel< SparseAdagradFunctorKernel<
T, 256><<<grid2, threads, 0, T, 256><<<grid2, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(grad_merge_data, grad_merge->rows().data(), .stream()>>>(grad_merge_data, grad_merge.rows().data(),
lr, param_data, moment_data, grad_width, lr, param_data, moment_data, grad_width,
epsilon); epsilon);
} }
......
...@@ -16,11 +16,14 @@ limitations under the License. */ ...@@ -16,11 +16,14 @@ limitations under the License. */
#include <math.h> // for sqrt in CPU and CUDA #include <math.h> // for sqrt in CPU and CUDA
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h" #include "paddle/operators/detail/safe_ref.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/for_range.h" #include "paddle/platform/for_range.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace scatter = paddle::operators::math::scatter;
template <typename T> template <typename T>
struct AdamFunctor { struct AdamFunctor {
T beta1_; T beta1_;
...@@ -79,6 +82,69 @@ struct AdamFunctor { ...@@ -79,6 +82,69 @@ struct AdamFunctor {
} }
}; };
template <typename T>
struct SparseAdamFunctor {
T beta1_;
T beta2_;
T epsilon_;
const T* beta1_pow_;
const T* beta2_pow_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
T* moment2_out_;
const T* lr_;
const T* grad_;
const T* param_;
T* param_out_;
const int64_t* rows_;
int64_t row_numel_;
SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* lr, const T* grad,
const T* param, T* param_out, const int64_t* rows,
int64_t row_numel)
: beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta2_pow_(beta2_pow),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
moment2_out_(mom2_out),
lr_(lr),
grad_(grad),
param_(param),
param_out_(param_out),
rows_(rows),
row_numel_(row_numel) {}
inline HOSTDEVICE void operator()(size_t i) const {
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
for (int64_t j = 0; j < row_numel_; ++j) {
T g = grad_[i * row_numel_ + j];
T mom1 = moment1_[rows_[i] * row_numel_ + j];
T mom2 = moment2_[rows_[i] * row_numel_ + j];
T lr = *lr_;
T p = param_[rows_[i] * row_numel_ + j];
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
moment1_out_[rows_[i] * row_numel_ + j] = mom1;
moment2_out_[rows_[i] * row_numel_ + j] = mom2;
param_out_[rows_[i] * row_numel_ + j] = p;
} // for col id
}
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> { class AdamOpKernel : public framework::OpKernel<T> {
public: public:
...@@ -90,7 +156,8 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -90,7 +156,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
T beta2 = static_cast<T>(ctx.Attr<float>("beta2")); T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon")); T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto& param = Ref(ctx.Input<LoDTensor>("Param"), "Must set Param"); auto& param = Ref(ctx.Input<LoDTensor>("Param"), "Must set Param");
auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad"); // auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
auto* grad_var = ctx.InputVar("Grad");
auto& mom1 = Ref(ctx.Input<LoDTensor>("Moment1"), "Must set Moment1"); auto& mom1 = Ref(ctx.Input<LoDTensor>("Moment1"), "Must set Moment1");
auto& mom2 = Ref(ctx.Input<LoDTensor>("Moment2"), "Must set Moment2"); auto& mom2 = Ref(ctx.Input<LoDTensor>("Moment2"), "Must set Moment2");
auto& lr = auto& lr =
...@@ -108,18 +175,48 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -108,18 +175,48 @@ class AdamOpKernel : public framework::OpKernel<T> {
auto& mom2_out = auto& mom2_out =
Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment1Out"); Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment1Out");
AdamFunctor<T> functor(beta1, beta2, epsilon, beta1_pow.template data<T>(), if (grad_var->IsType<framework::LoDTensor>()) {
beta2_pow.template data<T>(), auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
mom1.template data<T>(), AdamFunctor<T> functor(
mom1_out.template mutable_data<T>(ctx.GetPlace()), beta1, beta2, epsilon, beta1_pow.template data<T>(),
mom2.template data<T>(), beta2_pow.template data<T>(), mom1.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), mom1_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad.template data<T>(), mom2.template data<T>(),
param.template data<T>(), mom2_out.template mutable_data<T>(ctx.GetPlace()),
param_out.template mutable_data<T>(ctx.GetPlace())); lr.template data<T>(), grad.template data<T>(),
platform::ForRange<DeviceContext> for_range( param.template data<T>(),
static_cast<const DeviceContext&>(ctx.device_context()), param.numel()); param_out.template mutable_data<T>(ctx.GetPlace()));
for_range(functor); platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param.numel());
for_range(functor);
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto& grad =
Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
// merge duplicated rows if any.
scatter::MergeAdd<DeviceContext, T> merge_func;
auto grad_merge =
merge_func(ctx.template device_context<DeviceContext>(), grad);
auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>();
auto* rows = grad_merge.rows().data();
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
SparseAdamFunctor<T> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
grad_merge.rows().size());
for_range(functor);
} else {
PADDLE_THROW("Variable type not supported by adam_op");
}
} }
}; };
......
...@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/math/selected_rows_functor.h" #include <set>
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -179,6 +181,118 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, double>; ...@@ -179,6 +181,118 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int>; template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>; template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
// This is a separated namespace for manipulate SelectedRows typed
// data. Like merge duplicated rows, adding two SelectedRows etc.
//
// Another group of functors is called "scatter updates", which means
// use SelectedRows to update a dense tensor with different Ops, like
// add or mul.
namespace scatter {
size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
return std::find(rows.begin(), rows.end(), value) - rows.begin();
}
template <typename T>
struct MergeAdd<platform::CPUDeviceContext, T> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) {
framework::SelectedRows out;
auto input_rows = input.rows();
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto input_width = input.value().dims()[1];
out.set_rows(merge_rows);
out.set_height(input.height());
out.mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
constant_functor(context, out.mutable_value(), 0.0);
auto* out_data = out.mutable_value()->data<T>();
auto* input_data = input.value().data<T>();
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = FindPos(merge_rows, input_rows[i]);
for (int64_t j = 0; j < input_width; j++) {
out_data[out_i * input_width + j] += input_data[i * input_width + j];
}
}
return out;
}
};
template struct MergeAdd<platform::CPUDeviceContext, float>;
template struct MergeAdd<platform::CPUDeviceContext, double>;
template struct MergeAdd<platform::CPUDeviceContext, int>;
template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
template <typename T>
struct UpdateToTensor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const ScatterOps& op, const framework::SelectedRows& input1,
framework::Tensor* input2) {
auto in1_height = input1.height();
auto in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
auto* in1_data = in1_value.data<T>();
auto* input2_data = input2->data<T>();
// FIXME(typhoonzero): use macro fix the below messy code.
switch (op) {
case ScatterOps::ASSIGN:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] =
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::ADD:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] +=
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::SUB:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] -=
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::SUBBY:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] =
in1_data[i * in1_row_numel + j] -
input2_data[in1_rows[i] * in1_row_numel + j];
break;
case ScatterOps::MUL:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] *=
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::DIV:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] /=
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::DIVBY:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] =
in1_data[i * in1_row_numel + j] /
input2_data[in1_rows[i] * in1_row_numel + j];
break;
}
}
};
} // namespace scatter
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <set>
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h" #include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/cuda_helper.h" #include "paddle/platform/cuda_helper.h"
...@@ -222,6 +224,157 @@ template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>; ...@@ -222,6 +224,157 @@ template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>; template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>; template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>; template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
namespace scatter {
template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
T* out, const int64_t* out_rows,
size_t out_rows_size, int64_t row_numel) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
__shared__ size_t out_idx;
if (tid == 0) {
for (size_t i = 0; i < out_rows_size; i++) {
if (input_rows[ty] == out_rows[i]) {
out_idx = i;
}
}
}
__syncthreads();
input += ty * row_numel;
out += out_idx * row_numel;
for (int index = tid; index < row_numel; index += block_size) {
paddle::platform::CudaAtomicAdd(out + index, input[index]);
}
}
template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input) {
framework::SelectedRows out;
auto input_rows = input.rows();
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto input_width = input.value().dims()[1];
out.set_rows(merge_rows);
out.set_height(input.height());
out.mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
constant_functor(context, out.mutable_value(), 0.0);
auto* out_data = out.mutable_value()->data<T>();
auto* input_data = input.value().data<T>();
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid1(1, input_rows.size());
MergeAddKernel<
T, 256><<<grid1, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(input_data, input.rows().data(), out_data,
out.rows().data(), out.rows().size(),
input_width);
return out;
}
};
template struct MergeAdd<platform::CUDADeviceContext, float>;
template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
const int64_t* rows, const ScatterOps& op,
T* tensor_out, int64_t row_numel) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
selected_rows += ty * row_numel;
tensor_out += rows[ty] * row_numel;
// FIXME(typhoonzero): use macro fix the below messy code.
switch (op) {
case ScatterOps::ASSIGN:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] = selected_rows[index];
}
break;
case ScatterOps::ADD:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] += selected_rows[index];
}
break;
case ScatterOps::SUB:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] -= selected_rows[index];
}
break;
case ScatterOps::SUBBY:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] = selected_rows[index] - tensor_out[index];
}
break;
case ScatterOps::MUL:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] *= selected_rows[index];
}
break;
case ScatterOps::DIV:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] /= selected_rows[index];
}
break;
case ScatterOps::DIVBY:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] = selected_rows[index] / tensor_out[index];
}
break;
}
}
template <typename T>
struct UpdateToTensor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const ScatterOps& op, const framework::SelectedRows& input1,
framework::Tensor* input2) {
// NOTE: Use SelectedRowsAddToTensor for better performance
// no additional MergeAdd called.
MergeAdd<platform::CUDADeviceContext, T> merge_func;
auto merged_in1 = merge_func(context, input1);
auto in1_height = merged_in1.height();
auto in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
auto& in1_value = merged_in1.value();
auto& in1_rows = merged_in1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
auto* in1_data = in1_value.template data<T>();
auto* in2_data = input2->data<T>();
dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
dim3 grid(1, in1_rows.size());
UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
grid, threads, 0, context.stream()>>>(in1_data, in1_rows.data(), op,
in2_data, in1_row_numel);
}
};
} // namespace scatter
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/selected_rows.h" #include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#define INLINE_FOR2(sizei, sizej) \
for (int64_t i = 0; i < sizei; i++) \
for (int64_t j = 0; j < sizej; j++)
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -52,6 +57,78 @@ struct SelectedRowsAddToTensor { ...@@ -52,6 +57,78 @@ struct SelectedRowsAddToTensor {
framework::Tensor* input2); framework::Tensor* input2);
}; };
namespace scatter {
// functors for manuplating SelectedRows data
template <typename DeviceContext, typename T>
struct MergeAdd {
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input);
};
template <typename DeviceContext, typename T>
struct Add {
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2) {
framework::SelectedRows out;
out.set_rows(input1.rows());
out.set_height(input1.height());
out.mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
e_out.device(*context.eigen_device()) = e_in1 + e_in2;
return out;
}
};
template <typename DeviceContext, typename T>
struct Mul {
// multiply two SelectedRows
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2) {
framework::SelectedRows out;
out.set_rows(input1.rows());
out.set_height(input1.height());
out.mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
e_out.device(*context.eigen_device()) = e_in1 * e_in2;
return out;
}
// multiply scalar to SelectedRows
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input1,
const T input2) {
framework::SelectedRows out;
out.set_rows(input1.rows());
out.set_height(input1.height());
out.mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
e_out.device(*context.eigen_device()) = input2 * e_in1;
return out;
}
};
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
// out = seleted_rows_in / tensor
template <typename DeviceContext, typename T>
struct UpdateToTensor {
void operator()(const DeviceContext& context, const ScatterOps& op,
const framework::SelectedRows& input1,
framework::Tensor* input2);
};
} // namespace scatter
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
from paddle.v2.fluid import core
from paddle.v2.fluid.op import Operator
class TestAdamOp1(OpTest): class TestAdamOp1(OpTest):
...@@ -176,5 +178,124 @@ def adam_step(inputs, attributes): ...@@ -176,5 +178,124 @@ def adam_step(inputs, attributes):
return param_out, moment1_out, moment2_out return param_out, moment1_out, moment2_out
def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
'''
Simulate one step of the adam optimizer
:param inputs: dict of inputs
:param attributes: dict of attributes
:return tuple: tuple of output param, moment1, moment2,
beta1 power accumulator and beta2 power accumulator
'''
param = inputs['Param']
# grad = inputs['Grad']
moment1 = inputs['Moment1']
moment2 = inputs['Moment2']
lr = inputs['LearningRate']
beta1_pow = inputs['Beta1Pow']
beta2_pow = inputs['Beta2Pow']
beta1 = attributes['beta1']
beta2 = attributes['beta2']
epsilon = attributes['epsilon']
moment1_out = np.zeros(shape=[height, row_numel])
moment2_out = np.zeros(shape=[height, row_numel])
param_out = np.zeros(shape=[height, row_numel])
for idx, row_id in enumerate(rows):
moment1_out[row_id] = beta1 * moment1[row_id] + (1 - beta1
) * np_grad[idx]
moment2_out[row_id] = beta2 * moment2[row_id] + (
1 - beta2) * np.square(np_grad[idx])
lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
param_out[row_id] = param[row_id] - lr_t * (moment1_out[row_id] / (
np.sqrt(moment2_out[row_id]) + epsilon))
return param_out, moment1_out, moment2_out
class TestSparseAdamOp(unittest.TestCase):
def setup(self, scope, place):
beta1 = 0.78
beta2 = 0.836
epsilon = 1e-4
height = 10
rows = [0, 4, 7]
self.rows = rows
row_numel = 12
self.row_numel = row_numel
self.dense_inputs = {
"Param": np.full((height, row_numel), 5.0).astype("float32"),
"Moment1": np.full((height, row_numel), 5.0).astype("float32"),
"Moment2": np.full((height, row_numel), 5.0).astype("float32"),
'Beta1Pow': np.array([beta1**10]).astype("float32"),
'Beta2Pow': np.array([beta2**10]).astype("float32"),
"LearningRate": np.full((1), 2.0).astype("float32")
}
self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
grad_selected_rows = scope.var('Grad').get_selected_rows()
grad_selected_rows.set_height(height)
grad_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
np_array[0, 0] = 2.0
np_array[2, 8] = 4.0
grad_tensor = grad_selected_rows.get_tensor()
grad_tensor.set(np_array, place)
self.sparse_inputs = ["Grad"]
param_out, mom1, mom2 = adam_step_sparse(
self.dense_inputs, self.attrs, height, rows, row_numel, np_array)
self.outputs = {
"ParamOut": param_out,
"Moment1Out": mom1,
"Moment2Out": mom2
}
def check_with_place(self, place):
scope = core.Scope()
self.setup(scope, place)
op_args = dict()
for key, np_array in self.dense_inputs.iteritems():
var = scope.var(key).get_tensor()
var.set(np_array, place)
op_args[key] = key
for s in self.sparse_inputs:
op_args[s] = s
for s in self.outputs:
var = scope.var(s).get_tensor()
var.set(self.outputs[s], place)
op_args[s] = s
for k in self.attrs:
op_args[k] = self.attrs[k]
# create and run sgd operator
adam_op = Operator("adam", **op_args)
adam_op.run(scope, place)
for key, np_array in self.outputs.iteritems():
out_var = scope.var(key).get_tensor()
actual = np.array(out_var)
actual = actual.reshape([actual.size])
np_array = np_array.reshape([np_array.size])
for idx, row_id in enumerate(self.rows):
j = 0
while j < self.row_numel:
pos = row_id * self.row_numel + j
self.assertLess((actual[pos] - np_array[pos]) / actual[pos],
0.00001)
j += 1
def test_sparse_sgd(self):
places = [core.CPUPlace()]
if core.is_compile_gpu():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册