提交 96d55009 编写于 作者: Q Qiao Longfei

optimize code

上级 748ee35c
...@@ -275,7 +275,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -275,7 +275,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
} }
const framework::SelectedRows* has_value_input = nullptr; const framework::SelectedRows* has_value_input = nullptr;
for (auto* in : inputs) { for (auto* in : inputs) {
if (!in->rows().empty()) { if (in->rows().size() > 0) {
has_value_input = in; has_value_input = in;
break; break;
} }
...@@ -289,7 +289,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -289,7 +289,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
framework::SelectedRows& out = *output; framework::SelectedRows& out = *output;
std::set<int64_t> merged_row_set; std::set<int64_t> merged_row_set;
for (auto* input : inputs) { for (auto* input : inputs) {
if (input->rows().empty()) { if (input->rows().size() == 0) {
continue; continue;
} }
PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1], PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
...@@ -319,7 +319,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -319,7 +319,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (auto* input : inputs) { for (auto* input : inputs) {
if (input->rows().empty()) { if (input->rows().size() == 0) {
continue; continue;
} }
auto* input_data = input->value().data<T>(); auto* input_data = input->value().data<T>();
......
...@@ -311,7 +311,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -311,7 +311,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
} }
const framework::SelectedRows* has_value_input = nullptr; const framework::SelectedRows* has_value_input = nullptr;
for (auto* in : inputs) { for (auto* in : inputs) {
if (!in->rows().empty()) { if (in->rows().size() > 0) {
has_value_input = in; has_value_input = in;
break; break;
} }
...@@ -325,7 +325,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -325,7 +325,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
framework::SelectedRows& out = *output; framework::SelectedRows& out = *output;
std::set<int64_t> merged_row_set; std::set<int64_t> merged_row_set;
for (auto* input : inputs) { for (auto* input : inputs) {
if (input->rows().empty()) { if (input->rows().size() == 0) {
continue; continue;
} }
PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1], PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
...@@ -355,7 +355,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -355,7 +355,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
dim3 threads(block_size, 1); dim3 threads(block_size, 1);
for (auto* input : inputs) { for (auto* input : inputs) {
if (input->rows().empty()) { if (input->rows().size() == 0) {
continue; continue;
} }
auto* input_data = input->value().data<T>(); auto* input_data = input->value().data<T>();
......
...@@ -107,7 +107,7 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -107,7 +107,7 @@ class SumKernel : public framework::OpKernel<T> {
} else { } else {
for (auto &in_var : in_vars) { for (auto &in_var : in_vars) {
auto &in = in_var->Get<SelectedRows>(); auto &in = in_var->Get<SelectedRows>();
if (!in.rows().empty()) { if (in.rows().size() > 0) {
inputs.push_back(&in_var->Get<SelectedRows>()); inputs.push_back(&in_var->Get<SelectedRows>());
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册