提交 39917a13 编写于 作者: E eclipsycn 提交者: GitHub

Update mul_arm_func.h

上级 1c1aaa00
...@@ -22,7 +22,8 @@ namespace operators { ...@@ -22,7 +22,8 @@ namespace operators {
// 1、如果x,y维度都是2维, // 1、如果x,y维度都是2维,
// x = [[1,2], y = [[5,6], // x = [[1,2], y = [[5,6],
// [3,4]] [7,8]] // [3,4]] [7,8]]
// 运算结果为正常矩阵相乘。结果 out = [[1*5+2*7,1*6+2*8],[3*5+4*7, 3*6+4*8]] // 运算结果为正常矩阵相乘。结果 out =
// [[1*5+2*7,1*6+2*8],[3*5+4*7, 3*6+4*8]]
// //
// 2、如果x的维度大于2或者y的维度大于2,x的维度(2,3,4) ,y的维度(4,1,2) // 2、如果x的维度大于2或者y的维度大于2,x的维度(2,3,4) ,y的维度(4,1,2)
// x = [[[1,2,3,4], // x = [[[1,2,3,4],
...@@ -35,8 +36,8 @@ namespace operators { ...@@ -35,8 +36,8 @@ namespace operators {
// [[3,4]], // [[3,4]],
// [[5,6]], // [[5,6]],
// [[7,8]]] // [[7,8]]]
// 那么就需要借助x_num_col_dims和y_num_col_dims将x和y的维度转换为2维 // 需要借助x_num_col_dims和y_num_col_dims将x和y的维度转换为2维
// 从模型中读到参数,x_num_col_dims = 2,y_num_col_dims = 1,左开右闭 // 从模型中读到参数,x_num_col_dims = 2,y_num_col_dims = 1,左开右闭
// (1) 将x = (2,3,4)的index [0,x_num_col_dims)部分2,3相乘,得到6, // (1) 将x = (2,3,4)的index [0,x_num_col_dims)部分2,3相乘,得到6,
// [x_num_col_dims,xdim.size())部分4相乘,得到4, // [x_num_col_dims,xdim.size())部分4相乘,得到4,
// 将Tensor x的dims重写成(6,4) // 将Tensor x的dims重写成(6,4)
...@@ -50,7 +51,7 @@ namespace operators { ...@@ -50,7 +51,7 @@ namespace operators {
// [1,2,3,4], [7,8]] // [1,2,3,4], [7,8]]
// [2,3,4,5], // [2,3,4,5],
// [3,4,5,6]] // [3,4,5,6]]
// 结果x(6行4列)乘y(4行2列),按1中矩阵相乘,结果out(6行2列)保存在out里。 // 结果x(6行4列)乘y(4行2列),按1中矩阵相乘,结果out(6行2列)
template <typename P> template <typename P>
void MulCompute(const MulParam &param) { void MulCompute(const MulParam &param) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册