diff --git a/paddle/operators/math/selected_rows_functor.h b/paddle/operators/math/selected_rows_functor.h index 8adfca77f6930734b9c5dac43caa979d8705ed2a..eecd5e5362bbb398933a3ca287a2c25997f0bcfd 100644 --- a/paddle/operators/math/selected_rows_functor.h +++ b/paddle/operators/math/selected_rows_functor.h @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/selected_rows.h" #include "paddle/platform/device_context.h" @@ -70,13 +71,13 @@ struct Add { const framework::SelectedRows& input1, const framework::SelectedRows& input2, framework::SelectedRows* out) { - out->set_rows(input1->rows()); - out->set_height(input1->height()); - out->mutable_value()->mutable_data(input1->value().dims(), + out->set_rows(input1.rows()); + out->set_height(input1.height()); + out->mutable_value()->mutable_data(input1.value().dims(), context.GetPlace()); auto e_out = framework::EigenVector::Flatten(*(out->mutable_value())); - auto e_in1 = framework::EigenVector::Flatten(input1->value()); - auto e_in2 = framework::EigenVector::Flatten(input2->value()); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + auto e_in2 = framework::EigenVector::Flatten(input2.value()); e_out.device(*context.eigen_device()) = e_in1 + e_in2; } }; @@ -87,13 +88,13 @@ struct Mul { const framework::SelectedRows& input1, const framework::SelectedRows& input2, framework::SelectedRows* out) { - out->set_rows(input1->rows()); - out->set_height(input1->height()); - out->mutable_value()->mutable_data(input1->value().dims(), + out->set_rows(input1.rows()); + out->set_height(input1.height()); + out->mutable_value()->mutable_data(input1.value().dims(), context.GetPlace()); auto e_out = framework::EigenVector::Flatten(*(out->mutable_value())); - auto e_in1 = framework::EigenVector::Flatten(input1->value()); - auto e_in2 = framework::EigenVector::Flatten(input2->value()); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + auto e_in2 = framework::EigenVector::Flatten(input2.value()); e_out.device(*context.eigen_device()) = e_in1 * e_in2; } };