Commit 36e26a53, authored by Qiao Longfei, committed by qingqing01

Optimize bilinear tensor product op (#14485)

* optimize bilinear_tensor_product
* add set zero to set grad to 0.
Parent 4ec9de01
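
For context (not part of the patch): the op computes, for each batch example b and output channel i, out[b][i] = x_b^T · W_i · y_b + bias[i]. The patch fuses the Input(Weight) gradient into the same out_dim loop that already produces the X and Y gradients, so the x_scale intermediate is computed once per channel instead of twice, and it zero-initializes d_x and d_y up front because the per-channel GEMM calls accumulate into them. A minimal plain-Eigen sketch of the forward semantics (a standalone illustration, not Paddle code; names and shapes are made up):

```cpp
#include <vector>
#include <Eigen/Dense>

// Forward semantics: out[b][i] = x_b^T * W_i * y_b + bias[i].
//   X: batch_size x x_dim, Y: batch_size x y_dim,
//   W: out_dim slices of shape x_dim x y_dim, bias: out_dim.
Eigen::MatrixXf BilinearForward(const Eigen::MatrixXf& X,
                                const Eigen::MatrixXf& Y,
                                const std::vector<Eigen::MatrixXf>& W,
                                const Eigen::VectorXf& bias) {
  const int out_dim = static_cast<int>(W.size());
  Eigen::MatrixXf out(X.rows(), out_dim);
  for (int i = 0; i < out_dim; ++i) {
    // Row-wise x_b^T * W_i * y_b: elementwise product of (X * W_i)
    // with Y, summed over the y_dim axis.
    out.col(i) = (X * W[i]).cwiseProduct(Y).rowwise().sum();
    out.col(i).array() += bias(i);
  }
  return out;
}
```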
```diff
@@ -70,7 +70,7 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
     if (bias) {
       auto bias_vec = EigenMatrix<T>::From(*bias);
       Eigen::DSizes<int, 2> bcast(batch_size, 1);
-      output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat;
+      output_mat.device(place) = bias_vec.broadcast(bcast).eval() + output_mat;
     }
   }
 };
```
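
The only forward-kernel change is the added `.eval()`. Eigen tensor expressions are lazy, and coefficient access on a `broadcast` expression involves per-coefficient index arithmetic; `.eval()` materializes the broadcast into a temporary once before the addition, which is presumably the point of the change. A minimal standalone sketch of the pattern (made-up shapes, not Paddle code):

```cpp
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  const int batch_size = 8;
  const int out_dim = 4;
  Eigen::Tensor<float, 2> out(batch_size, out_dim);
  Eigen::Tensor<float, 2> bias_row(1, out_dim);
  out.setRandom();
  bias_row.setRandom();
  Eigen::DSizes<int, 2> bcast(batch_size, 1);
  // Without .eval(), bias_row.broadcast(bcast) + out re-evaluates the
  // broadcast expression for every coefficient of the sum; with .eval(),
  // the broadcast is evaluated into a temporary exactly once.
  out = bias_row.broadcast(bcast).eval() + out;
  return 0;
}
```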
```diff
@@ -99,13 +99,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
     auto d_out_mat = EigenMatrix<T>::From(*d_out);
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    // Create the intermediate variable to caculate the Output(Y@Grad).
+    // Create the intermediate variable to calculate the Output(Y@Grad).
     Tensor x_scale;
     x_scale.mutable_data<T>(framework::make_ddim({batch_size, x_dim}),
                             ctx.GetPlace());
     auto x_scale_mat = EigenMatrix<T>::From(x_scale);
-    // Create the intermediate variable to caculate the Output(X@Grad).
+    // Create the intermediate variable to calculate the Output(X@Grad).
     Tensor y_scale;
     y_scale.mutable_data<T>(framework::make_ddim({batch_size, y_dim}),
                             ctx.GetPlace());
```
```diff
@@ -113,65 +113,64 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
     math::SetConstant<DeviceContext, T> set_zero;
-    // Set Output(X@Grad) be zero.
     if (d_x) {
       d_x->mutable_data<T>(ctx.GetPlace());
       set_zero(dev_ctx, d_x, static_cast<T>(0));
     }
-    // Set Output(Y@Grad) be zero.
     if (d_y) {
       d_y->mutable_data<T>(ctx.GetPlace());
       set_zero(dev_ctx, d_y, static_cast<T>(0));
     }
+    if (d_weight) {
+      d_weight->mutable_data<T>(ctx.GetPlace());
+    }
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     // Caculate the Output(X@Grad) and Output(Y@Grad).
-    if (d_x || d_y) {
+    if (d_x || d_y || d_weight) {
       Eigen::DSizes<int, 2> bcast_for_x(1, y_dim);
       Eigen::DSizes<int, 2> bcast_for_y(1, x_dim);
+      Eigen::DSizes<int, 2> bcast_for_weight(1, x_dim);
       for (int i = 0; i < out_dim; ++i) {
         Tensor weight_i = weight->Slice(i, i + 1).Resize(
             framework::make_ddim({x_dim, y_dim}));
         auto output_vec = d_out_mat.chip(i, 1);
         if (d_x) {
           y_scale_mat.device(place) =
               output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
-                  .broadcast(bcast_for_x) *
+                  .broadcast(bcast_for_x)
+                  .eval() *
               y_mat;
           blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
                     y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
         }
-        if (d_y) {
-          x_scale_mat.device(place) =
+        if (d_y || d_weight) {
+          auto output_vec_y =
               output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
-                  .broadcast(bcast_for_y) *
-              x_mat;
-          blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
-                    x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
+                  .broadcast(bcast_for_y)
+                  .eval();
+          x_scale_mat.device(place) = output_vec_y * x_mat;
+          if (d_y) {
+            blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
+                      x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
+          }
+          if (d_weight) {
+            Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
+                framework::make_ddim({x_dim, y_dim}));
+            blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1,
+                      x_scale.data<T>(), y->data<T>(), 0, d_weight_i.data<T>());
+          }
         }
       }
     }
-    // Caculate the gradient of Input(Weight).
-    if (d_weight) {
-      d_weight->mutable_data<T>(ctx.GetPlace());
-      Eigen::DSizes<int, 2> bcast_for_weight(1, x_dim);
-      for (int i = 0; i < out_dim; ++i) {
-        Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
-            framework::make_ddim({x_dim, y_dim}));
-        auto output_vec = d_out_mat.chip(i, 1);
-        x_scale_mat.device(place) =
-            output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
-                .broadcast(bcast_for_weight) *
-            x_mat;
-        blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1,
-                  x_scale.data<T>(), y->data<T>(), 0, d_weight_i.data<T>());
-      }
-    }
-    // Caculate the gradient of Input(Bias).
+    // calculate the gradient of Input(Bias).
     if (d_bias) {
       d_bias->mutable_data<T>(ctx.GetPlace());
       auto d_bias_mat = framework::EigenVector<T>::Flatten(*d_bias);
```
......
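
Why the restructuring helps: before the patch, Output(Weight@Grad) was computed in a second loop over out_dim that repeated the broadcast of the i-th column of d_out and the x_scale product already needed for Output(Y@Grad). After the patch, a single loop computes x_scale once per channel and feeds both GEMMs; d_x and d_y accumulate across channels (beta = 1 in those GEMM calls, which is why they must be zeroed with set_zero first), while each d_weight slice is written exactly once (beta = 0). A plain-Eigen sketch of the fused backward math (a standalone illustration with hypothetical names, not Paddle code):

```cpp
#include <vector>
#include <Eigen/Dense>

// Gradients of out[b][i] = x_b^T * W_i * y_b over all channels i:
//   dX[b] += d_out[b][i] * y_b^T * W_i^T
//   dY[b] += d_out[b][i] * x_b^T * W_i
//   dW_i   = sum_b d_out[b][i] * x_b * y_b^T
void BilinearBackward(const Eigen::MatrixXf& X, const Eigen::MatrixXf& Y,
                      const std::vector<Eigen::MatrixXf>& W,
                      const Eigen::MatrixXf& dOut, Eigen::MatrixXf* dX,
                      Eigen::MatrixXf* dY, std::vector<Eigen::MatrixXf>* dW) {
  dX->setZero();  // accumulated across channels, mirroring set_zero above
  dY->setZero();
  const int out_dim = static_cast<int>(W.size());
  for (int i = 0; i < out_dim; ++i) {
    // Scale the rows of X and Y by the i-th column of d_out; x_scale is
    // shared by the dY and dW_i products, which is the point of the fusion.
    const Eigen::MatrixXf x_scale = dOut.col(i).asDiagonal() * X;
    const Eigen::MatrixXf y_scale = dOut.col(i).asDiagonal() * Y;
    *dX += y_scale * W[i].transpose();
    *dY += x_scale * W[i];
    (*dW)[i] = x_scale.transpose() * Y;  // written once per slice (beta = 0)
  }
}
```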