提交 da796dfe 编写于 作者: M minqiyang

Remove BinarySearch from Adam Op

test=develop
上级 e2130502
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm>
#include <set> #include <set>
#include <unordered_map> #include <unordered_map>
...@@ -301,6 +302,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -301,6 +302,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
} }
std::vector<int64_t> merge_rows(merged_row_set.begin(), std::vector<int64_t> merge_rows(merged_row_set.begin(),
merged_row_set.end()); merged_row_set.end());
// Put the merged row ids in ascending order when the functor was built
// with sorted_result == true. std::sort takes an iterator range; there is
// no overload accepting the container itself.
if (sorted_result_) {
  std::sort(merge_rows.begin(), merge_rows.end());
}
std::unordered_map<int64_t, size_t> rows_to_id; std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) { for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i; rows_to_id[merge_rows[i]] = i;
......
...@@ -78,6 +78,10 @@ namespace scatter { ...@@ -78,6 +78,10 @@ namespace scatter {
// functors for manipulating SelectedRows data // functors for manipulating SelectedRows data
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct MergeAdd { struct MergeAdd {
// Default construction: sorted_result_ is false, i.e. sorting of the merged
// row ids is not requested.
MergeAdd() : sorted_result_(false) {}
// sorted_result: stored verbatim in sorted_result_. Pass true to request that
// the merged row ids be sorted (the CPU specialization checks this flag
// before sorting its merge_rows vector).
explicit MergeAdd(bool sorted_result) : sorted_result_(sorted_result) {}
// unary functor, merge by adding duplicated rows in // unary functor, merge by adding duplicated rows in
// the input SelectedRows object. // the input SelectedRows object.
framework::SelectedRows operator()(const DeviceContext& context, framework::SelectedRows operator()(const DeviceContext& context,
...@@ -88,6 +92,9 @@ struct MergeAdd { ...@@ -88,6 +92,9 @@ struct MergeAdd {
void operator()(const DeviceContext& context, void operator()(const DeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs, const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output); framework::SelectedRows* output);
private:
// Set once by the constructors (false by default). The CPU specialization
// reads it to decide whether the merged row ids are sorted.
bool sorted_result_;
}; };
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
......
...@@ -158,7 +158,7 @@ struct AdamFunctor<T, CPUAdam> { ...@@ -158,7 +158,7 @@ struct AdamFunctor<T, CPUAdam> {
}; };
template <typename T> template <typename T>
struct SparseAdamFunctor { struct SparseAdamFunctor<T, GPUAdam> {
T beta1_; T beta1_;
T beta2_; T beta2_;
T epsilon_; T epsilon_;
...@@ -227,6 +227,78 @@ struct SparseAdamFunctor { ...@@ -227,6 +227,78 @@ struct SparseAdamFunctor {
} }
}; };
// CPU specialization of the sparse Adam update.
//
// Unlike the GPU variant, which is invoked once per element through ForRange,
// this functor is called a single time with the total element count and walks
// the whole parameter sequentially. That lets it match gradient rows with a
// single forward-moving cursor instead of searching rows_ for every element.
//
// Layout assumptions (to be confirmed against the caller):
//   * param_/moment1_/moment2_ and their *_out_ buffers hold `numel` values
//     arranged as rows of `row_numel_` contiguous values.
//   * rows_ holds `row_count_` row ids in strictly ascending order (the kernel
//     builds them with MergeAdd(true)).
//   * grad_ holds `row_count_` rows of `row_numel_` values; gradient row j
//     belongs to parameter row rows_[j].
template <typename T>
struct SparseAdamFunctor<T, CPUAdam> {
  T beta1_;
  T beta2_;
  T epsilon_;

  const T* beta1_pow_;
  const T* beta2_pow_;
  const T* moment1_;
  T* moment1_out_;
  const T* moment2_;
  T* moment2_out_;
  const T* lr_;
  const T* grad_;
  const T* param_;
  T* param_out_;

  const int64_t* rows_;
  int64_t row_numel_;
  int64_t row_count_;

  // Stores the hyper-parameters and raw buffer pointers; nothing is copied or
  // owned. All pointers must stay valid for as long as the functor is used.
  SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
                    const T* beta2_pow, const T* mom1, T* mom1_out,
                    const T* mom2, T* mom2_out, const T* lr, const T* grad,
                    const T* param, T* param_out, const int64_t* rows,
                    int64_t row_numel, int64_t row_count)
      : beta1_(beta1),
        beta2_(beta2),
        epsilon_(epsilon),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        moment1_(mom1),
        moment1_out_(mom1_out),
        moment2_(mom2),
        moment2_out_(mom2_out),
        lr_(lr),
        grad_(grad),
        param_(param),
        param_out_(param_out),
        rows_(rows),
        row_numel_(row_numel),
        row_count_(row_count) {}

  // Applies one Adam step to every element of the parameter.
  //
  // numel: total number of parameter elements. Only the first
  //        (numel / row_numel_) * row_numel_ elements are visited.
  //
  // Rows listed in rows_ use their gradient row; every other row only has its
  // moments decayed (equivalent to a zero gradient) before the parameter step.
  inline void operator()(size_t numel) const {
    // Without a positive row width no element can be addressed; this also
    // protects the division below.
    if (row_numel_ <= 0) {
      return;
    }

    // The bias-corrected learning rate is the same for every element, so it
    // is computed once up front.
    T lr = *lr_;
    const T beta1_pow = *beta1_pow_;
    const T beta2_pow = *beta2_pow_;
    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);

    const size_t row_numel = static_cast<size_t>(row_numel_);
    const size_t grad_rows =
        row_count_ > 0 ? static_cast<size_t>(row_count_) : 0U;
    const size_t param_rows = numel / row_numel;

    // `j` is the cursor into rows_/grad_; it only ever moves forward because
    // rows_ is ascending.
    for (size_t row = 0U, j = 0U; row != param_rows; ++row) {
      // Check the cursor bound first so rows_ is never read past its end.
      const bool has_grad =
          j < grad_rows && rows_[j] == static_cast<int64_t>(row);
      const size_t offset = row * row_numel;

      for (size_t k = 0U; k != row_numel; ++k) {
        const size_t idx = offset + k;
        T mom1 = moment1_[idx];
        T mom2 = moment2_[idx];
        T p = param_[idx];

        if (has_grad) {
          const T g = grad_[j * row_numel + k];
          mom1 = beta1_ * mom1 + (1 - beta1_) * g;
          mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
        } else {
          mom1 = beta1_ * mom1;
          mom2 = beta2_ * mom2;
        }
        p -= lr * (mom1 / (sqrt(mom2) + epsilon_));

        // Store the results in the output buffers.
        moment1_out_[idx] = mom1;
        moment2_out_[idx] = mom2;
        param_out_[idx] = p;
      }

      if (has_grad) {
        ++j;
      }
    }
  }
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> { class AdamOpKernel : public framework::OpKernel<T> {
public: public:
...@@ -316,7 +388,7 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -316,7 +388,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
} else { } else {
// merge duplicated rows if any. // merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor // The rows of grad_merge have been sorted inside MergeAdd functor
scatter::MergeAdd<DeviceContext, T> merge_func; scatter::MergeAdd<DeviceContext, T> merge_func(true);
auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope()) auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
.Var() .Var()
->GetMutable<framework::SelectedRows>(); ->GetMutable<framework::SelectedRows>();
...@@ -337,13 +409,13 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -337,13 +409,13 @@ class AdamOpKernel : public framework::OpKernel<T> {
} else { } else {
#endif #endif
rows = grad_merge.rows().data(); rows = grad_merge.rows().data();
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
} }
#endif #endif
auto row_numel = grad_tensor.numel() / grad_merge.rows().size(); auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
SparseAdamFunctor<T> functor( if (platform::is_cpu_place(ctx.GetPlace())) {
SparseAdamFunctor<T, CPUAdam> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(), beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(), beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()), mom1_out.template mutable_data<T>(ctx.GetPlace()),
...@@ -352,10 +424,25 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -352,10 +424,25 @@ class AdamOpKernel : public framework::OpKernel<T> {
lr.template data<T>(), grad_data, param.template data<T>(), lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel, param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size()); grad_merge.rows().size());
functor(param.numel());
} else if (platform::is_gpu_place(ctx.GetPlace())) {
SparseAdamFunctor<T, GPUAdam> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size());
// FIXME(minqiyang): remove BinarySearch in GPU later
platform::ForRange<DeviceContext> for_range( platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()), static_cast<const DeviceContext&>(ctx.device_context()),
param.numel()); param.numel());
for_range(functor); for_range(functor);
}
} else { } else {
PADDLE_THROW("Variable type not supported by adam_op"); PADDLE_THROW("Variable type not supported by adam_op");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册