diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h index aff30b915a7b7a4f23e88cdacaa8b04940af194c..df45d9ec660606d800f785564cdbc787789a160d 100644 --- a/src/operators/kernel/central-arm-func/mul_arm_func.h +++ b/src/operators/kernel/central-arm-func/mul_arm_func.h @@ -22,7 +22,8 @@ namespace operators { // 1、如果x,y维度都是2维, // x = [[1,2], y = [[5,6], // [3,4]] [7,8]] -// 运算结果为正常矩阵相乘。结果 out = [[1*5+2*7,1*6+2*8],[3*5+4*7, 3*6+4*8]] +// 运算结果为正常矩阵相乘。结果 out = +// [[1*5+2*7,1*6+2*8],[3*5+4*7, 3*6+4*8]] // // 2、如果x的维度大于2或者y的维度大于2,x的维度(2,3,4) ,y的维度(4,1,2) // x = [[[1,2,3,4], @@ -35,8 +36,8 @@ namespace operators { // [[3,4]], // [[5,6]], // [[7,8]]] -// 那么就需要借助x_num_col_dims和y_num_col_dims将x和y的维度转换为2维 -// 从模型中读到参数,x_num_col_dims = 2,y_num_col_dims = 1,左开右闭 +// 需要借助x_num_col_dims和y_num_col_dims将x和y的维度转换为2维 +// 从模型中读到参数,x_num_col_dims = 2,y_num_col_dims = 1,左开右闭 // (1) 将x = (2,3,4)的index [0,x_num_col_dims)部分2,3相乘,得到6, // [x_num_col_dims,xdim.size())部分4相乘,得到4, // 将Tensor x的dims重写成(6,4) @@ -50,7 +51,7 @@ namespace operators { // [1,2,3,4], [7,8]] // [2,3,4,5], // [3,4,5,6]] -// 结果x(6行4列)乘y(4行2列),按1中矩阵相乘,结果out(6行2列)保存在out里。 +// 结果x(6行4列)乘y(4行2列),按1中矩阵相乘,结果out(6行2列) template void MulCompute(const MulParam ¶m) {