提交 5fea8cd4 编写于 作者: M minqiyang

Add sorted_result parameter to SelectedRows Functor

test=develop
上级 da796dfe
......@@ -253,23 +253,26 @@ elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
template <typename T>
struct MergeAdd<platform::CPUDeviceContext, T> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) {
const framework::SelectedRows& input,
const bool sorted_result = false) {
framework::SelectedRows out;
(*this)(context, input, &out);
(*this)(context, input, &out, sorted_result);
return out;
}
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows* output,
const bool sorted_result = false) {
std::vector<const framework::SelectedRows*> inputs;
inputs.push_back(&input);
(*this)(context, inputs, output);
(*this)(context, inputs, output, sorted_result);
}
void operator()(const platform::CPUDeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output) {
framework::SelectedRows* output,
const bool sorted_result = false) {
if (inputs.size() == 0) {
VLOG(3) << "no input! return";
return;
......@@ -302,8 +305,8 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
}
std::vector<int64_t> merge_rows(merged_row_set.begin(),
merged_row_set.end());
if (sorted_result_) {
std::sort(merge_rows);
if (sorted_result) {
std::sort(merge_rows.begin(), merge_rows.end());
}
std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) {
......
......@@ -266,7 +266,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input) {
const framework::SelectedRows& input,
const bool sorted_result = false) {
framework::SelectedRows out;
(*this)(context, input, &out);
return out;
......
......@@ -78,23 +78,19 @@ namespace scatter {
// functors for manipulating SelectedRows data
template <typename DeviceContext, typename T>
struct MergeAdd {
MergeAdd() : sorted_result_(false) {}
explicit MergeAdd(bool sorted_result) : sorted_result_(sorted_result) {}
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input);
const framework::SelectedRows& input,
const bool sorted_result = false);
void operator()(const DeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output);
framework::SelectedRows* output,
const bool sorted_result = false);
void operator()(const DeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output);
private:
bool sorted_result_;
framework::SelectedRows* output,
const bool sorted_result = false);
};
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
......
......@@ -157,6 +157,9 @@ struct AdamFunctor<T, CPUAdam> {
}
};
template <typename T, typename Flavour>
struct SparseAdamFunctor;
template <typename T>
struct SparseAdamFunctor<T, GPUAdam> {
T beta1_;
......@@ -283,6 +286,7 @@ struct SparseAdamFunctor<T, CPUAdam> {
// Calculation
if (i == *(rows_ + j)) {
T g = grad_[j * row_numel_];
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
++j;
......@@ -388,12 +392,12 @@ class AdamOpKernel : public framework::OpKernel<T> {
} else {
// merge duplicated rows if any.
// The MergeAdd functor is asked to sort the rows of grad_merge (sorted_result = true)
scatter::MergeAdd<DeviceContext, T> merge_func(true);
scatter::MergeAdd<DeviceContext, T> merge_func;
auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
.Var()
->GetMutable<framework::SelectedRows>();
merge_func(ctx.template device_context<DeviceContext>(), grad,
grad_merge_var);
grad_merge_var, true);
grad_merge_ptr = grad_merge_var;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册