From 74b122889cbce2aa3add92784d0b4a621abfdf45 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 27 Dec 2017 21:08:40 +0800 Subject: [PATCH] wip --- paddle/operators/math/selected_rows_functor.h | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/paddle/operators/math/selected_rows_functor.h b/paddle/operators/math/selected_rows_functor.h index 8adfca77f6..eecd5e5362 100644 --- a/paddle/operators/math/selected_rows_functor.h +++ b/paddle/operators/math/selected_rows_functor.h @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/selected_rows.h" #include "paddle/platform/device_context.h" @@ -70,13 +71,13 @@ struct Add { const framework::SelectedRows& input1, const framework::SelectedRows& input2, framework::SelectedRows* out) { - out->set_rows(input1->rows()); - out->set_height(input1->height()); - out->mutable_value()->mutable_data(input1->value().dims(), + out->set_rows(input1.rows()); + out->set_height(input1.height()); + out->mutable_value()->mutable_data(input1.value().dims(), context.GetPlace()); auto e_out = framework::EigenVector::Flatten(*(out->mutable_value())); - auto e_in1 = framework::EigenVector::Flatten(input1->value()); - auto e_in2 = framework::EigenVector::Flatten(input2->value()); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + auto e_in2 = framework::EigenVector::Flatten(input2.value()); e_out.device(*context.eigen_device()) = e_in1 + e_in2; } }; @@ -87,13 +88,13 @@ struct Mul { const framework::SelectedRows& input1, const framework::SelectedRows& input2, framework::SelectedRows* out) { - out->set_rows(input1->rows()); - out->set_height(input1->height()); - out->mutable_value()->mutable_data(input1->value().dims(), + out->set_rows(input1.rows()); + out->set_height(input1.height()); + out->mutable_value()->mutable_data(input1.value().dims(), context.GetPlace()); auto e_out = framework::EigenVector::Flatten(*(out->mutable_value())); - auto e_in1 = framework::EigenVector::Flatten(input1->value()); - auto e_in2 = framework::EigenVector::Flatten(input2->value()); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + auto e_in2 = framework::EigenVector::Flatten(input2.value()); e_out.device(*context.eigen_device()) = e_in1 * e_in2; } }; -- GitLab