Commit 5f99ae90 authored by peterzhang2029

refine notation in bilinear_tensor_product_op.h

Parent 5cf82041
@@ -27,10 +27,6 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class BilinearTensorProductKernel : public framework::OpKernel<T> {
public:
@@ -49,7 +45,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
auto weight_dims = weight->dims();
auto place = ctx.GetEigenDevice<Place>();
// Create the intermediate variables.
// Create the intermediate variable to calculate the result of
// Input(X) multiplied by Input(Weight_i), the formula is:
// left_mul = X * Weight_i.
Tensor left_mul;
left_mul.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
ctx.GetPlace());
@@ -95,11 +93,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
auto d_out_mat = EigenMatrix<T>::From(*d_out);
auto place = ctx.GetEigenDevice<Place>();
// Create the intermediate variables for gradient.
// Create the intermediate variable to calculate the Output(Y@Grad).
Tensor x_scale;
x_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[1]}),
ctx.GetPlace());
auto x_scale_mat = EigenMatrix<T>::From(x_scale);
// Create the intermediate variable to calculate the Output(X@Grad).
Tensor y_scale;
y_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
ctx.GetPlace());
@@ -107,19 +107,19 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> set_zero;
// Set X@Grad to zero at first.
// Set Output(X@Grad) to zero.
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
set_zero(ctx.device_context(), d_x, static_cast<T>(0));
}
// Set Y@Grad to zero at first.
// Set Output(Y@Grad) to zero.
if (d_y) {
d_y->mutable_data<T>(ctx.GetPlace());
set_zero(ctx.device_context(), d_y, static_cast<T>(0));
}
// Calculate the X@Grad and Y@Grad.
// Calculate the Output(X@Grad) and Output(Y@Grad).
if (d_x || d_y) {
Eigen::DSizes<int, 2> bcast_for_x(1, weight_dims[2]);
Eigen::DSizes<int, 2> bcast_for_y(1, weight_dims[1]);
@@ -150,7 +150,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
}
}
// Calculate the gradient of Weight.
// Calculate the gradient of Input(Weight).
if (d_weight) {
d_weight->mutable_data<T>(ctx.GetPlace());
Eigen::DSizes<int, 2> bcast_for_weight(1, weight_dims[1]);
@@ -169,7 +169,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
}
}
// Calculate the gradient of Bias.
// Calculate the gradient of Input(Bias).
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
auto d_bias_mat = EigenMatrix<T>::From(*d_bias);
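
For reference, below is a minimal standalone sketch of the forward computation that the refined comments describe (left_mul = X * Weight_i, followed by an element-wise product with Y summed over the last dimension, plus an optional bias). It assumes the usual bilinear-tensor-product shapes; the function and variable names here are illustrative only and are not taken from the PaddlePaddle kernel.

```cpp
// Sketch of the bilinear tensor product forward pass:
//   out[b][i] = sum_{j,k} X[b][j] * W[i][j][k] * Y[b][k] + bias[i]
// Shapes (assumed): X is (batch, M), Y is (batch, N), W is (out_dim, M, N).
#include <vector>

std::vector<std::vector<double>> BilinearTensorProduct(
    const std::vector<std::vector<double>>& X,
    const std::vector<std::vector<double>>& Y,
    const std::vector<std::vector<std::vector<double>>>& W,
    const std::vector<double>& bias) {
  const size_t batch = X.size(), out_dim = W.size();
  const size_t M = X[0].size(), N = Y[0].size();
  std::vector<std::vector<double>> out(batch,
                                       std::vector<double>(out_dim, 0.0));
  for (size_t b = 0; b < batch; ++b) {
    for (size_t i = 0; i < out_dim; ++i) {
      // left_mul = X[b] * W_i, a length-N row vector.
      std::vector<double> left_mul(N, 0.0);
      for (size_t k = 0; k < N; ++k) {
        for (size_t j = 0; j < M; ++j) {
          left_mul[k] += X[b][j] * W[i][j][k];
        }
      }
      // Dot left_mul with Y[b] and add the bias term (if any).
      double acc = bias.empty() ? 0.0 : bias[i];
      for (size_t k = 0; k < N; ++k) {
        acc += left_mul[k] * Y[b][k];
      }
      out[b][i] = acc;
    }
  }
  return out;
}
```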