提交 da796dfe 编写于 作者: M minqiyang

Remove BinarySearch from Adam Op

test=develop
上级 e2130502
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <set>
#include <unordered_map>
......@@ -301,6 +302,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
}
std::vector<int64_t> merge_rows(merged_row_set.begin(),
merged_row_set.end());
if (sorted_result_) {
std::sort(merge_rows);
}
std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i;
......
......@@ -78,6 +78,10 @@ namespace scatter {
// functors for manuplating SelectedRows data
template <typename DeviceContext, typename T>
struct MergeAdd {
MergeAdd() : sorted_result_(false) {}
explicit MergeAdd(bool sorted_result) : sorted_result_(sorted_result) {}
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
framework::SelectedRows operator()(const DeviceContext& context,
......@@ -88,6 +92,9 @@ struct MergeAdd {
void operator()(const DeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output);
private:
bool sorted_result_;
};
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
......
......@@ -158,7 +158,7 @@ struct AdamFunctor<T, CPUAdam> {
};
template <typename T>
struct SparseAdamFunctor {
struct SparseAdamFunctor<T, GPUAdam> {
T beta1_;
T beta2_;
T epsilon_;
......@@ -227,6 +227,78 @@ struct SparseAdamFunctor {
}
};
template <typename T>
struct SparseAdamFunctor<T, CPUAdam> {
T beta1_;
T beta2_;
T epsilon_;
const T* beta1_pow_;
const T* beta2_pow_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
T* moment2_out_;
const T* lr_;
const T* grad_;
const T* param_;
T* param_out_;
const int64_t* rows_;
int64_t row_numel_;
int64_t row_count_;
SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* lr, const T* grad,
const T* param, T* param_out, const int64_t* rows,
int64_t row_numel, int64_t row_count)
: beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta2_pow_(beta2_pow),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
moment2_out_(mom2_out),
lr_(lr),
grad_(grad),
param_(param),
param_out_(param_out),
rows_(rows),
row_numel_(row_numel),
row_count_(row_count) {}
inline void operator()(size_t numel) const {
// lr could be reuse
T lr = *lr_;
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
for (size_t i = 0U, j = 0U; i != numel; ++i) {
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T p = param_[i];
// Calculation
if (i == *(rows_ + j)) {
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
++j;
} else {
mom1 = beta1_ * mom1;
mom2 = beta2_ * mom2;
}
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
// Write back to global memory
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
param_out_[i] = p;
}
}
};
template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> {
public:
......@@ -316,7 +388,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
} else {
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter::MergeAdd<DeviceContext, T> merge_func;
scatter::MergeAdd<DeviceContext, T> merge_func(true);
auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
.Var()
->GetMutable<framework::SelectedRows>();
......@@ -337,13 +409,13 @@ class AdamOpKernel : public framework::OpKernel<T> {
} else {
#endif
rows = grad_merge.rows().data();
#if defined(PADDLE_WITH_CUDA)
}
#endif
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
SparseAdamFunctor<T> functor(
if (platform::is_cpu_place(ctx.GetPlace())) {
SparseAdamFunctor<T, CPUAdam> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
......@@ -352,10 +424,25 @@ class AdamOpKernel : public framework::OpKernel<T> {
lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size());
functor(param.numel());
} else if (platform::is_gpu_place(ctx.GetPlace())) {
SparseAdamFunctor<T, GPUAdam> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size());
// FIXME(minqiyang): remove BinarySearch in GPU later
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param.numel());
for_range(functor);
}
} else {
PADDLE_THROW("Variable type not supported by adam_op");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册