Unverified · Commit 6429d2a8 · authored by Zeng Jinle, committed by GitHub

Merge pull request #16188 from sneaxiy/fix_const_cast

Remove const_cast in optimizers
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <math.h>  // for sqrt in CPU and CUDA
 #include <Eigen/Dense>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/threadpool.h"
@@ -311,17 +312,17 @@ struct SparseAdamFunctor<T, CPUAdam> {
     T beta1_pow = *beta1_pow_;
     T beta2_pow = *beta2_pow_;
     lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
-    size_t row_count = numel / row_numel_;
-    for (size_t i = 0U, j = 0U; i != row_count; ++i) {
+    int64_t row_count = static_cast<int64_t>(numel / row_numel_);
+    for (int64_t i = 0, j = 0; i != row_count; ++i) {
       if (i == *(rows_ + j)) {
-        for (size_t k = 0U; k != row_numel_; ++k) {
+        for (int64_t k = 0; k != row_numel_; ++k) {
           T g = grad_[j * row_numel_ + k];
           adam_update(i * row_numel_ + k, g);
         }
         ++j;
       } else {
-        for (size_t k = 0U; k != row_numel_; ++k) {
+        for (int64_t k = 0; k != row_numel_; ++k) {
           T mom1 = moment1_[i * row_numel_ + k];
           T mom2 = moment2_[i * row_numel_ + k];
           T p = param_[i * row_numel_ + k];
@@ -427,43 +428,23 @@ class AdamOpKernel : public framework::OpKernel<T> {
         }
       }
-      framework::SelectedRows cpu_grad_merge;
+      framework::SelectedRows tmp_grad_merge;
       const framework::SelectedRows* grad_merge_ptr;
       if (is_strict_sorted) {
         grad_merge_ptr = &grad;
       } else {
         // merge duplicated rows if any.
         // The rows of grad_merge have been sorted inside MergeAdd functor
-        framework::SelectedRows* grad_merge_var;
         scatter::MergeAdd<DeviceContext, T> merge_func;
-        if (platform::is_cpu_place(ctx.GetPlace())) {
-          grad_merge_var = &cpu_grad_merge;
-        } else {
-          // FIXME(qiao): GPU also need to fix this
-          grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
-                               .Var()
-                               ->GetMutable<framework::SelectedRows>();
-        }
         merge_func(ctx.template device_context<DeviceContext>(), grad,
-                   grad_merge_var, true);
-        grad_merge_ptr = grad_merge_var;
+                   &tmp_grad_merge, true);
+        grad_merge_ptr = &tmp_grad_merge;
       }
       auto& grad_merge = *grad_merge_ptr;
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
-      const int64_t* rows = nullptr;
-// When compiled without CUDA, the CUDAData() interface should not be
-// provided.
-#if defined(PADDLE_WITH_CUDA)
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        rows = grad_merge.rows().CUDAData(ctx.GetPlace());
-      } else {
-#endif
-        rows = grad_merge.rows().data();
-#if defined(PADDLE_WITH_CUDA)
-      }
-#endif
+      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
       if (platform::is_cpu_place(ctx.GetPlace())) {
@@ -488,7 +469,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
           }
         }
 #ifndef _WIN32
-        else if (FLAGS_inner_op_parallelism > 1 &&
+        else if (FLAGS_inner_op_parallelism > 1 &&  // NOLINT
                  min_row_size_to_use_multithread > 0 &&
                  param.dims()[0] > min_row_size_to_use_multithread) {
           VLOG(3) << "use multi thread, inner_op_parallelism="
@@ -516,11 +497,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
           for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
             int64_t start = i * line_in_each_thread;
             int64_t end = (i + 1) * line_in_each_thread;
-            if (start >= param_row_count) {
+            if (start >= static_cast<int64_t>(param_row_count)) {
               break;
             }
-            if (end > param_row_count) {
-              end = param_row_count;
+            if (end > static_cast<int64_t>(param_row_count)) {
+              end = static_cast<int64_t>(param_row_count);
             }
             fs.push_back(
                 framework::Async([&functor, &row_id_to_grad_row_offset,
@@ -545,8 +526,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
             }
           }
           for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
         }
 #endif  // !_WIN32
-        else {
+        else {  // NOLINT
           functor(param.numel());
         }
       } else if (platform::is_gpu_place(ctx.GetPlace())) {
...
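The recurring change in this file (and mirrored below in the momentum and rmsprop kernels) is to stop obtaining scratch SelectedRows storage by const_cast-ing the execution scope and creating an unnamed variable in it, and instead use a plain function-local temporary whose address is handed to the MergeAdd functor. Below is a minimal, self-contained C++ sketch of that "after" pattern; the SelectedRows, MergeAdd, and ComputeAfter names here are simplified stand-ins for illustration only, not the real Paddle classes.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for framework::SelectedRows, reduced to its row index list.
struct SelectedRows {
  std::vector<int64_t> rows;
};

// Stand-in for scatter::MergeAdd: writes its result into caller-provided storage.
void MergeAdd(const SelectedRows& in, SelectedRows* out) {
  out->rows = in.rows;  // pretend duplicate rows were merged and sorted here
}

void ComputeAfter(const SelectedRows& grad) {
  // After the patch: the scratch buffer is an ordinary local temporary.
  // No scope mutation, no const_cast, and its lifetime ends with the kernel call.
  SelectedRows tmp_grad_merge;
  const SelectedRows* grad_merge_ptr = &tmp_grad_merge;
  MergeAdd(grad, &tmp_grad_merge);
  std::cout << "merged rows: " << grad_merge_ptr->rows.size() << "\n";
}

int main() {
  SelectedRows grad{{0, 2, 2, 5}};
  ComputeAfter(grad);
  return 0;
}

The previous approach created a fresh unnamed variable in the (conceptually read-only) scope on every kernel invocation, which is why the const_cast and the accompanying FIXME could be dropped together.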
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <memory>
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -69,6 +70,7 @@ class MomentumOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ParamOut", param_dim);
     ctx->SetOutputDim("VelocityOut", param_dim);
   }
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
@@ -351,23 +353,14 @@ class MomentumOpKernel : public framework::OpKernel<T> {
         VLOG(3) << "Grad SelectedRows contains no data!";
         return;
       }
-      auto* merged_grad = const_cast<framework::Scope&>(ctx.scope())
-                              .Var()
-                              ->GetMutable<framework::SelectedRows>();
+      framework::SelectedRows tmp_merged_grad;
+      framework::SelectedRows* merged_grad = &tmp_merged_grad;
       math::scatter::MergeAdd<DeviceContext, T> merge_func;
       merge_func(ctx.template device_context<DeviceContext>(), *grad,
                  merged_grad);
-      const int64_t* rows = nullptr;
-#ifdef PADDLE_WITH_CUDA
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        rows = merged_grad->rows().CUDAData(ctx.GetPlace());
-      } else {
-#endif
-        rows = merged_grad->rows().data();
-#ifdef PADDLE_WITH_CUDA
-      }
-#endif
+      const int64_t* rows = merged_grad->rows().Data(ctx.GetPlace());
       int64_t row_numel =
           merged_grad->value().numel() / merged_grad->rows().size();
       platform::ForRange<DeviceContext> for_range(
...
@@ -216,24 +216,14 @@ class RmspropOpKernel : public framework::OpKernel<T> {
         }
       } else if (grad_var->IsType<framework::SelectedRows>()) {
         auto &grad = grad_var->Get<framework::SelectedRows>();
-        auto *merged_grad = const_cast<framework::Scope &>(ctx.scope())
-                                .Var()
-                                ->GetMutable<framework::SelectedRows>();
+        framework::SelectedRows tmp_merged_grad;
+        framework::SelectedRows *merged_grad = &tmp_merged_grad;
         math::scatter::MergeAdd<DeviceContext, T> merge_func;
         merge_func(dev_ctx, grad, merged_grad);
         platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
-        const int64_t *rows;
-#ifdef PADDLE_WITH_CUDA
-        if (platform::is_gpu_place(ctx.GetPlace())) {
-          rows = merged_grad->rows().CUDAData(ctx.GetPlace());
-        } else {
-#endif
-          rows = merged_grad->rows().data();
-#ifdef PADDLE_WITH_CUDA
-        }
-#endif
+        const int64_t *rows = merged_grad->rows().Data(ctx.GetPlace());
         auto &merged_tensor = merged_grad->value();
         int64_t row_count = merged_grad->rows().size();
         int64_t row_numel = merged_tensor.numel() / row_count;
...
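The second recurring change across all three kernels replaces the #ifdef PADDLE_WITH_CUDA branching between rows().data() and rows().CUDAData(place) with a single rows().Data(place) call, moving the place dispatch into the row-vector type itself. The sketch below is a rough, self-contained illustration of that idea; the Place enum and MixedVector class are simplified stand-ins and not Paddle's actual mixed_vector.h API, which also handles host/device synchronization and real device memory.

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Simplified stand-in for a device place descriptor.
enum class Place { kCPU, kGPU };

class MixedVector {
 public:
  explicit MixedVector(std::vector<int64_t> host) : host_(std::move(host)) {}

  // One accessor that hides the place dispatch, so kernels no longer need
  // #ifdef blocks to pick between a CPU and a GPU data pointer.
  const int64_t* Data(Place place) const {
    if (place == Place::kGPU) {
      // In a real implementation this would return (and lazily copy to)
      // device memory; here both branches use the host buffer.
      return host_.data();
    }
    return host_.data();
  }

 private:
  std::vector<int64_t> host_;
};

int main() {
  MixedVector rows({0, 3, 7});
  const int64_t* ptr = rows.Data(Place::kCPU);
  std::cout << "first row index: " << ptr[0] << "\n";
  return 0;
}

With the dispatch hidden behind one accessor, the kernels stay identical for CPU-only and CUDA builds, which is what allowed the preprocessor blocks above to be deleted.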