Commit ebeee930 authored by shippingwang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into shufflechannel

@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <algorithm>
 #include <set>
 #include <unordered_map>
@@ -252,23 +253,26 @@ elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
 template <typename T>
 struct MergeAdd<platform::CPUDeviceContext, T> {
   framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
-                                     const framework::SelectedRows& input) {
+                                     const framework::SelectedRows& input,
+                                     const bool sorted_result = false) {
     framework::SelectedRows out;
-    (*this)(context, input, &out);
+    (*this)(context, input, &out, sorted_result);
     return out;
   }
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::SelectedRows& input,
-                  framework::SelectedRows* output) {
+                  framework::SelectedRows* output,
+                  const bool sorted_result = false) {
     std::vector<const framework::SelectedRows*> inputs;
     inputs.push_back(&input);
-    (*this)(context, inputs, output);
+    (*this)(context, inputs, output, sorted_result);
   }
   void operator()(const platform::CPUDeviceContext& context,
                   const std::vector<const framework::SelectedRows*>& inputs,
-                  framework::SelectedRows* output) {
+                  framework::SelectedRows* output,
+                  const bool sorted_result = false) {
     if (inputs.size() == 0) {
       VLOG(3) << "no input! return";
       return;
@@ -301,6 +305,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
     }
     std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                     merged_row_set.end());
+    if (sorted_result) {
+      std::sort(merge_rows.begin(), merge_rows.end());
+    }
     std::unordered_map<int64_t, size_t> rows_to_id;
     for (size_t i = 0; i < merge_rows.size(); ++i) {
       rows_to_id[merge_rows[i]] = i;
...
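To make the change concrete, here is a minimal standalone sketch of the same merge-by-row-id logic. It uses plain STL with a hypothetical ToySelectedRows stand-in (not the Paddle API): duplicated row ids are accumulated by element-wise addition into an unordered map, and the merged row ids are sorted only when sorted_result is set, mirroring the std::sort branch added above.

// Standalone sketch of the MergeAdd row-merging logic; toy types, not Paddle.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

struct ToySelectedRows {                   // hypothetical stand-in for SelectedRows
  std::vector<int64_t> rows;               // possibly duplicated row ids
  std::vector<std::vector<float>> values;  // one dense row per entry in `rows`
};

ToySelectedRows MergeAddToy(const ToySelectedRows& input,
                            bool sorted_result = false) {
  // Accumulate duplicated row ids by element-wise addition.
  std::unordered_map<int64_t, std::vector<float>> merged;
  for (size_t i = 0; i < input.rows.size(); ++i) {
    auto& acc = merged[input.rows[i]];
    if (acc.empty()) acc.assign(input.values[i].size(), 0.f);
    for (size_t k = 0; k < input.values[i].size(); ++k) acc[k] += input.values[i][k];
  }
  // Collect the unique row ids; their order is unspecified unless we sort,
  // mirroring the `if (sorted_result) std::sort(...)` branch above.
  std::vector<int64_t> merge_rows;
  merge_rows.reserve(merged.size());
  for (const auto& kv : merged) merge_rows.push_back(kv.first);
  if (sorted_result) std::sort(merge_rows.begin(), merge_rows.end());

  ToySelectedRows out;
  for (int64_t r : merge_rows) {
    out.rows.push_back(r);
    out.values.push_back(merged[r]);
  }
  return out;
}

int main() {
  ToySelectedRows in{{3, 1, 3}, {{1, 1}, {2, 2}, {10, 10}}};
  ToySelectedRows out = MergeAddToy(in, /*sorted_result=*/true);
  for (size_t i = 0; i < out.rows.size(); ++i)
    std::cout << out.rows[i] << " -> {" << out.values[i][0] << ", "
              << out.values[i][1] << "}\n";  // 1 -> {2, 2}, then 3 -> {11, 11}
}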
@@ -266,7 +266,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
 template <typename T>
 struct MergeAdd<platform::CUDADeviceContext, T> {
   framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
-                                     const framework::SelectedRows& input) {
+                                     const framework::SelectedRows& input,
+                                     const bool sorted_result = false) {
     framework::SelectedRows out;
     (*this)(context, input, &out);
     return out;
@@ -274,7 +275,8 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::SelectedRows& input,
-                  framework::SelectedRows* output) {
+                  framework::SelectedRows* output,
+                  const bool sorted_result = false) {
     framework::Vector<int64_t> input_rows(input.rows());
     if (input_rows.size() == 0) {
       return;
@@ -312,7 +314,8 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& context,
                   const std::vector<const framework::SelectedRows*>& inputs,
-                  framework::SelectedRows* output) {
+                  framework::SelectedRows* output,
+                  const bool sorted_result = false) {
     if (inputs.size() == 0) {
       VLOG(3) << "no input! return";
       return;
...
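Judging from the hunks shown here, the CUDA specialization only grows the extra sorted_result parameter to keep its three overloads signature-compatible with the CPU version: the unary overload still forwards (*this)(context, input, &out) without the flag, and no std::sort is added on this path. On the GPU the merged rows are instead located with a binary search inside the kernel (see the FIXME in AdamOpKernel further down), so only the CPU specialization actually honors sorted_result.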
@@ -81,13 +81,16 @@ struct MergeAdd {
   // unary functor, merge by adding duplicated rows in
   // the input SelectedRows object.
   framework::SelectedRows operator()(const DeviceContext& context,
-                                     const framework::SelectedRows& input);
+                                     const framework::SelectedRows& input,
+                                     const bool sorted_result = false);
   void operator()(const DeviceContext& context,
                   const framework::SelectedRows& input,
-                  framework::SelectedRows* output);
+                  framework::SelectedRows* output,
+                  const bool sorted_result = false);
   void operator()(const DeviceContext& context,
                   const std::vector<const framework::SelectedRows*>& inputs,
-                  framework::SelectedRows* output);
+                  framework::SelectedRows* output,
+                  const bool sorted_result = false);
 };

 enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
...
@@ -157,8 +157,11 @@ struct AdamFunctor<T, CPUAdam> {
   }
 };

+template <typename T, typename Flavour>
+struct SparseAdamFunctor;
+
 template <typename T>
-struct SparseAdamFunctor {
+struct SparseAdamFunctor<T, GPUAdam> {
   T beta1_;
   T beta2_;
   T epsilon_;
@@ -236,6 +239,106 @@ struct SparseAdamFunctor {
   }
 };
+template <typename T>
+struct SparseAdamFunctor<T, CPUAdam> {
+  T beta1_;
+  T beta2_;
+  T epsilon_;
+
+  const T* beta1_pow_;
+  const T* beta2_pow_;
+  const T* moment1_;
+  T* moment1_out_;
+  const T* moment2_;
+  T* moment2_out_;
+  const T* lr_;
+  const T* grad_;
+  const T* param_;
+  T* param_out_;
+
+  const int64_t* rows_;
+  int64_t row_numel_;
+  int64_t row_count_;
+
+  SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
+                    const T* beta2_pow, const T* mom1, T* mom1_out,
+                    const T* mom2, T* mom2_out, const T* lr, const T* grad,
+                    const T* param, T* param_out, const int64_t* rows,
+                    int64_t row_numel, int64_t row_count, bool lazy_mode)
+      : beta1_(beta1),
+        beta2_(beta2),
+        epsilon_(epsilon),
+        beta1_pow_(beta1_pow),
+        beta2_pow_(beta2_pow),
+        moment1_(mom1),
+        moment1_out_(mom1_out),
+        moment2_(mom2),
+        moment2_out_(mom2_out),
+        lr_(lr),
+        grad_(grad),
+        param_(param),
+        param_out_(param_out),
+        rows_(rows),
+        row_numel_(row_numel),
+        row_count_(row_count) {}
+
+  inline HOSTDEVICE void adam_update(size_t i, T g) const {
+    // The following code is the same as the dense Adam update.
+    T mom1 = moment1_[i];
+    T mom2 = moment2_[i];
+    T lr = *lr_;
+    T beta1_pow = *beta1_pow_;
+    T beta2_pow = *beta2_pow_;
+    T p = param_[i];
+
+    // Calculation: bias-corrected step size, then moment and parameter update.
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
+    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
+    p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+
+    // Write back to global memory
+    moment1_out_[i] = mom1;
+    moment2_out_[i] = mom2;
+    param_out_[i] = p;
+  }
+
+  inline void operator()(size_t numel) const {
+    // lr could be reused across every element of this sweep.
+    T lr = *lr_;
+    T beta1_pow = *beta1_pow_;
+    T beta2_pow = *beta2_pow_;
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+    size_t row_count = numel / row_numel_;
+
+    for (size_t i = 0U, j = 0U; i != row_count; ++i) {
+      if (i == *(rows_ + j)) {
+        // This parameter row has a gradient row: full Adam update.
+        for (size_t k = 0U; k != row_numel_; ++k) {
+          T g = grad_[j * row_numel_ + k];
+          adam_update(i * row_numel_ + k, g);
+        }
+        ++j;
+      } else {
+        // No gradient for this row: decay the moments as if g == 0.
+        for (size_t k = 0U; k != row_numel_; ++k) {
+          T mom1 = moment1_[i * row_numel_ + k];
+          T mom2 = moment2_[i * row_numel_ + k];
+          T p = param_[i * row_numel_ + k];
+
+          mom1 = beta1_ * mom1;
+          mom2 = beta2_ * mom2;
+
+          p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+
+          // Write back to global memory
+          moment1_out_[i * row_numel_ + k] = mom1;
+          moment2_out_[i * row_numel_ + k] = mom2;
+          param_out_[i * row_numel_ + k] = p;
+        }
+      }
+    }
+  }
+};
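Both branches of the loop above compute the same Adam step: the else branch is simply adam_update with g == 0, so the moments decay by beta1_ and beta2_ and the parameter still moves. The effective step size folds in the bias correction, lr * sqrt(1 - beta2^t) / (1 - beta1^t). The walk over rows_ with the cursor j assumes the row ids are sorted ascending, which is exactly what the new sorted_result flag guarantees. A minimal standalone driver of the same logic follows (plain C++, toy data and sizes of my own choosing; it adds an explicit j < rows.size() bounds check that the functor leaves implicit):

// Toy driver for the sorted-rows sparse Adam sweep; not the Paddle kernel.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t row_numel = 2;  // elements per row
  const size_t row_count = 4;   // rows in the dense parameter
  std::vector<int64_t> rows = {1, 3};      // sorted ids of rows with gradients
  std::vector<float> grad = {1, 1, 2, 2};  // one dense row per entry in `rows`
  std::vector<float> param(row_count * row_numel, 1.f);
  std::vector<float> mom1(row_count * row_numel, 0.1f);
  std::vector<float> mom2(row_count * row_numel, 0.1f);
  const float beta1 = 0.9f, beta2 = 0.999f, eps = 1e-8f, lr0 = 0.01f;
  const float beta1_pow = beta1, beta2_pow = beta2;  // step t = 1
  const float lr = lr0 * std::sqrt(1 - beta2_pow) / (1 - beta1_pow);

  for (size_t i = 0, j = 0; i != row_count; ++i) {
    // Rows without a gradient take the same update with g == 0.
    const bool has_grad = j < rows.size() && static_cast<int64_t>(i) == rows[j];
    for (int64_t k = 0; k != row_numel; ++k) {
      const size_t idx = i * row_numel + k;
      const float g = has_grad ? grad[j * row_numel + k] : 0.f;
      mom1[idx] = beta1 * mom1[idx] + (1 - beta1) * g;
      mom2[idx] = beta2 * mom2[idx] + (1 - beta2) * g * g;
      param[idx] -= lr * (mom1[idx] / (std::sqrt(mom2[idx]) + eps));
    }
    if (has_grad) ++j;
  }
  for (float p : param) std::printf("%f\n", p);
}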
 template <typename DeviceContext, typename T>
 class AdamOpKernel : public framework::OpKernel<T> {
  public:
@@ -331,7 +434,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
                 .Var()
                 ->GetMutable<framework::SelectedRows>();
         merge_func(ctx.template device_context<DeviceContext>(), grad,
-                   grad_merge_var);
+                   grad_merge_var, true);
         grad_merge_ptr = grad_merge_var;
       }
@@ -347,13 +450,13 @@ class AdamOpKernel : public framework::OpKernel<T> {
       } else {
 #endif
         rows = grad_merge.rows().data();
 #if defined(PADDLE_WITH_CUDA)
       }
 #endif
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
-      SparseAdamFunctor<T> functor(
+      if (platform::is_cpu_place(ctx.GetPlace())) {
+        SparseAdamFunctor<T, CPUAdam> functor(
             beta1, beta2, epsilon, beta1_pow.template data<T>(),
             beta2_pow.template data<T>(), mom1.template data<T>(),
             mom1_out.template mutable_data<T>(ctx.GetPlace()),
@@ -362,8 +465,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
             lr.template data<T>(), grad_data, param.template data<T>(),
             param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
             grad_merge.rows().size(), lazy_mode);
-      VLOG(3) << "lazy_mode :" << lazy_mode;
-      if (lazy_mode && platform::is_cpu_place(ctx.GetPlace())) {
+        if (lazy_mode) {
           size_t row_count = grad_merge.rows().size();
           std::vector<int64_t> cpu_rows(grad_merge.rows());
           for (size_t row_index = 0; row_index < row_count; ++row_index) {
@@ -373,6 +476,20 @@ class AdamOpKernel : public framework::OpKernel<T> {
             }
           }
         } else {
+          functor(param.numel());
+        }
+      } else if (platform::is_gpu_place(ctx.GetPlace())) {
+        SparseAdamFunctor<T, GPUAdam> functor(
+            beta1, beta2, epsilon, beta1_pow.template data<T>(),
+            beta2_pow.template data<T>(), mom1.template data<T>(),
+            mom1_out.template mutable_data<T>(ctx.GetPlace()),
+            mom2.template data<T>(),
+            mom2_out.template mutable_data<T>(ctx.GetPlace()),
+            lr.template data<T>(), grad_data, param.template data<T>(),
+            param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
+            grad_merge.rows().size(), lazy_mode);
+        // FIXME(minqiyang): remove BinarySearch in GPU later
         platform::ForRange<DeviceContext> for_range(
             static_cast<const DeviceContext&>(ctx.device_context()),
             param.numel());
...
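With these pieces in place, the kernel's sparse branch dispatches on the place explicitly. On CPU it builds a SparseAdamFunctor<T, CPUAdam> over gradient rows that MergeAdd has already merged and, thanks to the new sorted_result = true argument, sorted; in lazy_mode it walks only the rows present in the merged gradient, and otherwise it calls functor(param.numel()) to sweep every parameter row in one sequential pass. On GPU it builds a SparseAdamFunctor<T, GPUAdam> and keeps the previous behavior: platform::ForRange launches one update per element and each thread binary-searches the row table, which is what the FIXME proposes to remove later. This is also why the old lazy_mode && platform::is_cpu_place(...) test disappears: the place check has moved up to the outer branch.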