提交 5fea8cd4 编写于 作者: M minqiyang

Add sorted_result parameter to SelectedRows Functor

test=develop
上级 da796dfe
...@@ -253,23 +253,26 @@ elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas, ...@@ -253,23 +253,26 @@ elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
template <typename T> template <typename T>
struct MergeAdd<platform::CPUDeviceContext, T> { struct MergeAdd<platform::CPUDeviceContext, T> {
// Return-by-value overload: builds a local SelectedRows, fills it through the
// pointer-output overload (forwarding sorted_result), and returns it.
framework::SelectedRows operator()(const platform::CPUDeviceContext& context, framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) { const framework::SelectedRows& input,
const bool sorted_result = false) {
framework::SelectedRows out; framework::SelectedRows out;
(*this)(context, input, &out); (*this)(context, input, &out, sorted_result);
return out; return out;
} }
// Single-input overload: wraps the address of `input` in a one-element vector
// and delegates to the multi-input overload, forwarding sorted_result.
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input, const framework::SelectedRows& input,
framework::SelectedRows* output) { framework::SelectedRows* output,
const bool sorted_result = false) {
std::vector<const framework::SelectedRows*> inputs; std::vector<const framework::SelectedRows*> inputs;
inputs.push_back(&input); inputs.push_back(&input);
(*this)(context, inputs, output); (*this)(context, inputs, output, sorted_result);
} }
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs, const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output) { framework::SelectedRows* output,
const bool sorted_result = false) {
if (inputs.size() == 0) { if (inputs.size() == 0) {
VLOG(3) << "no input! return"; VLOG(3) << "no input! return";
return; return;
...@@ -302,8 +305,8 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -302,8 +305,8 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
} }
std::vector<int64_t> merge_rows(merged_row_set.begin(), std::vector<int64_t> merge_rows(merged_row_set.begin(),
merged_row_set.end()); merged_row_set.end());
if (sorted_result_) { if (sorted_result) {
std::sort(merge_rows); std::sort(merge_rows.begin(), merge_rows.end());
} }
std::unordered_map<int64_t, size_t> rows_to_id; std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) { for (size_t i = 0; i < merge_rows.size(); ++i) {
......
...@@ -266,7 +266,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows, ...@@ -266,7 +266,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
template <typename T> template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> { struct MergeAdd<platform::CUDADeviceContext, T> {
// Return-by-value overload: builds a local SelectedRows, fills it through the
// pointer-output overload, and returns it.
framework::SelectedRows operator()(const platform::CUDADeviceContext& context, framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input) { const framework::SelectedRows& input,
const bool sorted_result = false) {
framework::SelectedRows out; framework::SelectedRows out;
// NOTE(review): sorted_result is accepted above but is not passed to this
// call, whereas the CPUDeviceContext specialization forwards it -- confirm
// that ignoring it here is intended.
(*this)(context, input, &out); (*this)(context, input, &out);
return out; return out;
} }
......
...@@ -78,23 +78,19 @@ namespace scatter { ...@@ -78,23 +78,19 @@ namespace scatter {
// functors for manipulating SelectedRows data // functors for manipulating SelectedRows data
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct MergeAdd { struct MergeAdd {
MergeAdd() : sorted_result_(false) {}
explicit MergeAdd(bool sorted_result) : sorted_result_(sorted_result) {}
// unary functor, merge by adding duplicated rows in // unary functor, merge by adding duplicated rows in
// the input SelectedRows object. // the input SelectedRows object.
framework::SelectedRows operator()(const DeviceContext& context, framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input); const framework::SelectedRows& input,
const bool sorted_result = false);
void operator()(const DeviceContext& context, void operator()(const DeviceContext& context,
const framework::SelectedRows& input, const framework::SelectedRows& input,
framework::SelectedRows* output); framework::SelectedRows* output,
const bool sorted_result = false);
void operator()(const DeviceContext& context, void operator()(const DeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs, const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output); framework::SelectedRows* output,
const bool sorted_result = false);
private:
bool sorted_result_;
}; };
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
......
...@@ -157,6 +157,9 @@ struct AdamFunctor<T, CPUAdam> { ...@@ -157,6 +157,9 @@ struct AdamFunctor<T, CPUAdam> {
} }
}; };
// Primary template, declared without a definition; the GPUAdam and CPUAdam
// specializations that follow provide the implementations.
template <typename T, typename Flavour>
struct SparseAdamFunctor;
template <typename T> template <typename T>
struct SparseAdamFunctor<T, GPUAdam> { struct SparseAdamFunctor<T, GPUAdam> {
T beta1_; T beta1_;
...@@ -283,6 +286,7 @@ struct SparseAdamFunctor<T, CPUAdam> { ...@@ -283,6 +286,7 @@ struct SparseAdamFunctor<T, CPUAdam> {
// Calculation // Calculation
if (i == *(rows_ + j)) { if (i == *(rows_ + j)) {
T g = grad_[j * row_numel_];
mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
++j; ++j;
...@@ -388,12 +392,12 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -388,12 +392,12 @@ class AdamOpKernel : public framework::OpKernel<T> {
} else { } else {
// merge duplicated rows if any. // merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor // The rows of grad_merge have been sorted inside MergeAdd functor
scatter::MergeAdd<DeviceContext, T> merge_func(true); scatter::MergeAdd<DeviceContext, T> merge_func;
auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope()) auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
.Var() .Var()
->GetMutable<framework::SelectedRows>(); ->GetMutable<framework::SelectedRows>();
merge_func(ctx.template device_context<DeviceContext>(), grad, merge_func(ctx.template device_context<DeviceContext>(), grad,
grad_merge_var); grad_merge_var, true);
grad_merge_ptr = grad_merge_var; grad_merge_ptr = grad_merge_var;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册