未验证 提交 9fdc5896 编写于 作者: qnqinan's avatar qnqinan 提交者: GitHub

Merge branch 'develop' into develop

...@@ -19,6 +19,40 @@ limitations under the License. */ ...@@ -19,6 +19,40 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// 1、如果x,y维度都是2维,
// x = [[1,2], y = [[5,6],
// [3,4]] [7,8]]
// 运算结果为正常矩阵相乘。结果 out =
// [[1*5+2*7,1*6+2*8],[3*5+4*7, 3*6+4*8]]
//
// 2、如果x的维度大于2或者y的维度大于2,x的维度(2,3,4) ,y的维度(4,1,2)
// x = [[[1,2,3,4],
// [2,3,4,5],
// [3,4,5,6]],
// [[1,2,3,4],
// [2,3,4,5],
// [3,4,5,6]]]
// y = [[[1,2]],
// [[3,4]],
// [[5,6]],
// [[7,8]]]
// 需要借助x_num_col_dims和y_num_col_dims将x和y的维度转换为2维
// 从模型中读到参数,x_num_col_dims = 2,y_num_col_dims = 1,左开右闭
// (1) 将x = (2,3,4)的index [0,x_num_col_dims)部分2,3相乘,得到6,
// [x_num_col_dims,xdim.size())部分4相乘,得到4,
// 将Tensor x的dims重写成(6,4)
// (2) 将y = (4,1,2)的index [0,y_num_col_dims)部分4相乘,得到4,
// [y_num_col_dims,ydim.size())部分1,2相乘,得到2,
// 将Tensor y的dims重写成(4,2)
// 并不影响x,y在内存中的分布。
// x = [[1,2,3,4], y = [[1,2],
// [2,3,4,5], [3,4],
// [3,4,5,6], 矩阵乘法 [5,6],
// [1,2,3,4], [7,8]]
// [2,3,4,5],
// [3,4,5,6]]
// 结果x(6行4列)乘y(4行2列),按1中矩阵相乘,结果out(6行2列)
template <typename P> template <typename P>
void MulCompute(const MulParam &param) { void MulCompute(const MulParam &param) {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册