提交 0321e1f8 编写于 作者: D dangqingqing

Fix bilinear_tensor_product_op in debug mode.

上级 a76b6144
...@@ -63,6 +63,7 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> { ...@@ -63,6 +63,7 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans, math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
batch_size, y_dim, x_dim, 1, x->data<T>(), batch_size, y_dim, x_dim, 1, x->data<T>(),
weight_mat.data<T>(), 0, left_mul.data<T>()); weight_mat.data<T>(), 0, left_mul.data<T>());
Eigen::array<int, 2> shape({{static_cast<int>(out->dims()[0]), 1}});
output_col_vec.device(place) = output_col_vec.device(place) =
(left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1)); (left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
} }
...@@ -174,7 +175,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> { ...@@ -174,7 +175,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
// Caculate the gradient of Input(Bias). // Caculate the gradient of Input(Bias).
if (d_bias) { if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace()); d_bias->mutable_data<T>(ctx.GetPlace());
auto d_bias_mat = EigenMatrix<T>::From(*d_bias); auto d_bias_mat = framework::EigenVector<T>::Flatten(*d_bias);
d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes<int, 1>(0)); d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes<int, 1>(0));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册