提交 74b12288 编写于 作者: T typhoonzero

wip

上级 d48a0e4e
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"
......@@ -70,13 +71,13 @@ struct Add {
const framework::SelectedRows& input1,
const framework::SelectedRows& input2,
framework::SelectedRows* out) {
out->set_rows(input1->rows());
out->set_height(input1->height());
out->mutable_value()->mutable_data<T>(input1->value().dims(),
out->set_rows(input1.rows());
out->set_height(input1.height());
out->mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1->value());
auto e_in2 = framework::EigenVector<T>::Flatten(input2->value());
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
e_out.device(*context.eigen_device()) = e_in1 + e_in2;
}
};
......@@ -87,13 +88,13 @@ struct Mul {
const framework::SelectedRows& input1,
const framework::SelectedRows& input2,
framework::SelectedRows* out) {
out->set_rows(input1->rows());
out->set_height(input1->height());
out->mutable_value()->mutable_data<T>(input1->value().dims(),
out->set_rows(input1.rows());
out->set_height(input1.height());
out->mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1->value());
auto e_in2 = framework::EigenVector<T>::Flatten(input2->value());
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
e_out.device(*context.eigen_device()) = e_in1 * e_in2;
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册