提交 5f99ae90 编写于 作者: P peterzhang2029

refine notation in bilinear_tensor_product_op.h

上级 5cf82041
...@@ -27,10 +27,6 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -27,10 +27,6 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class BilinearTensorProductKernel : public framework::OpKernel<T> { class BilinearTensorProductKernel : public framework::OpKernel<T> {
public: public:
...@@ -49,7 +45,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> { ...@@ -49,7 +45,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
auto weight_dims = weight->dims(); auto weight_dims = weight->dims();
auto place = ctx.GetEigenDevice<Place>(); auto place = ctx.GetEigenDevice<Place>();
// Create the intermediate variables. // Create the intermediate variable to caculate the result of
// Input(X) multiplied by Input(Weight_i), the formula is:
// left_mul = X Weight_i.
Tensor left_mul; Tensor left_mul;
left_mul.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}), left_mul.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
ctx.GetPlace()); ctx.GetPlace());
...@@ -95,11 +93,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> { ...@@ -95,11 +93,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
auto d_out_mat = EigenMatrix<T>::From(*d_out); auto d_out_mat = EigenMatrix<T>::From(*d_out);
auto place = ctx.GetEigenDevice<Place>(); auto place = ctx.GetEigenDevice<Place>();
// Create the intermediate variables for gradient. // Create the intermediate variable to caculate the Output(Y@Grad).
Tensor x_scale; Tensor x_scale;
x_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[1]}), x_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[1]}),
ctx.GetPlace()); ctx.GetPlace());
auto x_scale_mat = EigenMatrix<T>::From(x_scale); auto x_scale_mat = EigenMatrix<T>::From(x_scale);
// Create the intermediate variable to caculate the Output(X@Grad).
Tensor y_scale; Tensor y_scale;
y_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}), y_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
ctx.GetPlace()); ctx.GetPlace());
...@@ -107,19 +107,19 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> { ...@@ -107,19 +107,19 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> set_zero; math::SetConstant<Place, T> set_zero;
// Set X@Grad be zero at first. // Set Output(X@Grad) be zero.
if (d_x) { if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
set_zero(ctx.device_context(), d_x, static_cast<T>(0)); set_zero(ctx.device_context(), d_x, static_cast<T>(0));
} }
// Set Y@Grad be zero at first. // Set Output(Y@Grad) be zero.
if (d_y) { if (d_y) {
d_y->mutable_data<T>(ctx.GetPlace()); d_y->mutable_data<T>(ctx.GetPlace());
set_zero(ctx.device_context(), d_y, static_cast<T>(0)); set_zero(ctx.device_context(), d_y, static_cast<T>(0));
} }
// Caculate the X@Grad and Y@Grad. // Caculate the Output(X@Grad) and Output(Y@Grad).
if (d_x || d_y) { if (d_x || d_y) {
Eigen::DSizes<int, 2> bcast_for_x(1, weight_dims[2]); Eigen::DSizes<int, 2> bcast_for_x(1, weight_dims[2]);
Eigen::DSizes<int, 2> bcast_for_y(1, weight_dims[1]); Eigen::DSizes<int, 2> bcast_for_y(1, weight_dims[1]);
...@@ -150,7 +150,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> { ...@@ -150,7 +150,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
} }
} }
// Caculate the gradient of Weight. // Caculate the gradient of Input(Weight).
if (d_weight) { if (d_weight) {
d_weight->mutable_data<T>(ctx.GetPlace()); d_weight->mutable_data<T>(ctx.GetPlace());
Eigen::DSizes<int, 2> bcast_for_weight(1, weight_dims[1]); Eigen::DSizes<int, 2> bcast_for_weight(1, weight_dims[1]);
...@@ -169,7 +169,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> { ...@@ -169,7 +169,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
} }
} }
// Caculate the gradient of Bias. // Caculate the gradient of Input(Bias).
if (d_bias) { if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace()); d_bias->mutable_data<T>(ctx.GetPlace());
auto d_bias_mat = EigenMatrix<T>::From(*d_bias); auto d_bias_mat = EigenMatrix<T>::From(*d_bias);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册