提交 52e13225 编写于 作者: Z Zhen Wang

add int8_t gemm and enable MulOp to support int8_t.

上级 1c893a02
...@@ -32,7 +32,7 @@ template <typename Dtype> ...@@ -32,7 +32,7 @@ template <typename Dtype>
vector<string> OperatorBase<Dtype>::GetInputKeys() const { vector<string> OperatorBase<Dtype>::GetInputKeys() const {
auto it = op_input_output_key.find(type_); auto it = op_input_output_key.find(type_);
if (it == op_input_output_key.end()) { if (it == op_input_output_key.end()) {
DLOG << type_ << " has no outputs"; DLOG << type_ << " has no inputs";
return {}; return {};
} }
return it->second.first; return it->second.first;
......
...@@ -338,10 +338,12 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) { ...@@ -338,10 +338,12 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
for (int i = 0; i < tensor.numel(); i += stride) { for (int i = 0; i < tensor.numel(); i += stride) {
if (tensor.type() == typeid(float)) { if (tensor.type() == typeid(float)) {
printer << tensor.data<float>()[i] << " "; printer << tensor.data<float>()[i] << " ";
} else if (tensor.type() == typeid(int32_t)) {
printer << tensor.data<int32_t>()[i] << " ";
} else if (tensor.type() == typeid(int64_t)) { } else if (tensor.type() == typeid(int64_t)) {
printer << tensor.data<int64_t>()[i] << " "; printer << tensor.data<int64_t>()[i] << " ";
} else if (tensor.type() == typeid(int8_t)) { } else if (tensor.type() == typeid(int8_t)) {
printer << tensor.data<int8_t>()[i] << " "; printer << static_cast<int32_t>(tensor.data<int8_t>()[i]) << " ";
} }
} }
#endif #endif
......
...@@ -25,12 +25,15 @@ bool MulKernel<CPU, float>::Init(MulParam<CPU> *param) { ...@@ -25,12 +25,15 @@ bool MulKernel<CPU, float>::Init(MulParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void MulKernel<CPU, float>::Compute(const MulParam<CPU> &param) const { void MulKernel<CPU, float>::Compute(const MulParam<CPU> &param) const {
MulCompute<float>(param); MulCompute<float>(param);
param.Out()->set_lod(param.InputX()->lod()); param.Out()->set_lod(param.InputX()->lod());
} }
template class MulKernel<CPU, float>;
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -58,7 +58,7 @@ void MulCompute(const MulParam<CPU> &param) { ...@@ -58,7 +58,7 @@ void MulCompute(const MulParam<CPU> &param) {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY(); const Tensor *input_y = param.InputY();
Tensor *out = param.Out(); Tensor *out = param.Out();
out->mutable_data<float>();
const Tensor x_matrix = const Tensor x_matrix =
input_x->dims().size() > 2 input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims()) ? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
...@@ -71,15 +71,21 @@ void MulCompute(const MulParam<CPU> &param) { ...@@ -71,15 +71,21 @@ void MulCompute(const MulParam<CPU> &param) {
if (out_dim.size() != 2) { if (out_dim.size() != 2) {
out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
} }
if (param.InputX()->type() == typeid(int8_t)) {
out->mutable_data<int32_t>();
math::matmul<int8_t>(x_matrix, false, y_matrix, false,
static_cast<int8_t>(1), out, static_cast<int8_t>(0));
} else {
out->mutable_data<float>();
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1), math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(0)); out, static_cast<float>(0));
}
if (out_dim.size() != 2) { if (out_dim.size() != 2) {
out->Resize(out_dim); out->Resize(out_dim);
} }
} }
template class MulKernel<CPU, float>;
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <stdint-gcc.h>
#include <string> #include <string>
#include "common/log.h" #include "common/log.h"
......
...@@ -34,8 +34,10 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, ...@@ -34,8 +34,10 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
const int8_t *a_ptr, *b_ptr; const int8_t *a_ptr, *b_ptr;
a_ptr = a; a_ptr = a;
b_ptr = b; b_ptr = b;
int32_t kc1 = k >> 1; int32_t kc1 = k >> 3;
int32_t kc2 = k & 1; int32_t kc2 = k & 7;
int32_t kc3 = kc2 >> 1;
int32_t kc4 = kc2 & 1;
int32_t step = sizeof(int32_t) * ldc; int32_t step = sizeof(int32_t) * ldc;
asm volatile( asm volatile(
// q4-q15: save 48 results // q4-q15: save 48 results
...@@ -57,16 +59,20 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, ...@@ -57,16 +59,20 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
"subs %[kc1], %[kc1], #1 \n\t" "subs %[kc1], %[kc1], #1 \n\t"
"blt 1f \n\t" "blt 1f \n\t"
"0: \n\t" "0: \n\t"
"pld [%[a_ptr], #64] \n\t"
"pld [%[b_ptr], #64] \n\t"
"vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0 "vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0
"vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0 used "vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0
"vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" // B row0, B row1, q1 used // used
"vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" // B row0, B
// row1, q1
// used
"vmov.s8 q2, #0 \n\t" // q2 used "vmov.s8 q2, #0 \n\t" // q2 used
"vdup.s8 d6, d0[0] \n\t" "vdup.s8 d6, d0[0] \n\t"
"vdup.s8 d7, d1[0] \n\t" // q3 used "vdup.s8 d7, d1[0] \n\t" // q3 used
"vmlal.s8 q2, d2, d6 \n\t" // A col00 * B row0 "vmlal.s8 q2, d2, d6 \n\t" // A col00 * B
"vmlal.s8 q2, d3, d7 \n\t" // A col10 * B row1, q3 free // row0
"vmlal.s8 q2, d3, d7 \n\t" // A col10 * B
// row1, q3
// free
"vaddw.s16 q4, q4, d4 \n\t" "vaddw.s16 q4, q4, d4 \n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0 "vaddw.s16 q5, q5, d5 \n\t" // res row 0
"vmov.s8 q2, #0 \n\t" "vmov.s8 q2, #0 \n\t"
...@@ -90,7 +96,59 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, ...@@ -90,7 +96,59 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
"vmlal.s8 q2, d3, d7 \n\t" "vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q10, q10, d4 \n\t" "vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3 "vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vmov.s8 q2, #0. \n\t"
"vdup.s8 d6, d0[4] \n\t"
"vdup.s8 d7, d1[4] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[5] \n\t"
"vdup.s8 d7, d1[5] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5
"vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0
"vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0
// used
"vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" // B row0, B
// row1, q1
// used
"vmov.s8 q2, #0 \n\t" // q2 used
"vdup.s8 d6, d0[0] \n\t"
"vdup.s8 d7, d1[0] \n\t" // q3 used
"vmlal.s8 q2, d2, d6 \n\t" // A col00 * B
// row0
"vmlal.s8 q2, d3, d7 \n\t" // A col10 * B
// row1, q3
// free
"vaddw.s16 q4, q4, d4 \n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[1] \n\t"
"vdup.s8 d7, d1[1] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q6, q6, d4 \n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1
"vmov.s8 q2, #0 \n\t" "vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[2] \n\t"
"vdup.s8 d7, d1[2] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q8, q8, d4 \n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[3] \n\t"
"vdup.s8 d7, d1[3] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vmov.s8 q2, #0. \n\t"
"vdup.s8 d6, d0[4] \n\t" "vdup.s8 d6, d0[4] \n\t"
"vdup.s8 d7, d1[4] \n\t" "vdup.s8 d7, d1[4] \n\t"
"vmlal.s8 q2, d2, d6 \n\t" "vmlal.s8 q2, d2, d6 \n\t"
...@@ -105,11 +163,175 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, ...@@ -105,11 +163,175 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
"vaddw.s16 q14, q14, d4 \n\t" "vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5 "vaddw.s16 q15, q15, d5 \n\t" // res row 5
"subs %[kc1], %[kc1], #1 \n\t" "vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0
"vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0
// used
"vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" // B row0, B
// row1, q1
// used
"vmov.s8 q2, #0 \n\t" // q2 used
"vdup.s8 d6, d0[0] \n\t"
"vdup.s8 d7, d1[0] \n\t" // q3 used
"vmlal.s8 q2, d2, d6 \n\t" // A col00 * B
// row0
"vmlal.s8 q2, d3, d7 \n\t" // A col10 * B
// row1, q3
// free
"vaddw.s16 q4, q4, d4 \n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[1] \n\t"
"vdup.s8 d7, d1[1] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q6, q6, d4 \n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[2] \n\t"
"vdup.s8 d7, d1[2] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q8, q8, d4 \n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[3] \n\t"
"vdup.s8 d7, d1[3] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vmov.s8 q2, #0. \n\t"
"vdup.s8 d6, d0[4] \n\t"
"vdup.s8 d7, d1[4] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[5] \n\t"
"vdup.s8 d7, d1[5] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5
"vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0
"vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0
// used
"vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" // B row0, B
// row1, q1
// used
"vmov.s8 q2, #0 \n\t" // q2 used
"vdup.s8 d6, d0[0] \n\t"
"vdup.s8 d7, d1[0] \n\t" // q3 used
"vmlal.s8 q2, d2, d6 \n\t" // A col00 * B
// row0
"vmlal.s8 q2, d3, d7 \n\t" // A col10 * B
// row1, q3
// free
"vaddw.s16 q4, q4, d4 \n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[1] \n\t"
"vdup.s8 d7, d1[1] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q6, q6, d4 \n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[2] \n\t"
"vdup.s8 d7, d1[2] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q8, q8, d4 \n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[3] \n\t"
"vdup.s8 d7, d1[3] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vmov.s8 q2, #0. \n\t"
"vdup.s8 d6, d0[4] \n\t"
"vdup.s8 d7, d1[4] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[5] \n\t"
"vdup.s8 d7, d1[5] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5
"subs %[kc1], %[kc1], #1 \n\t" // last <8 rows
"bge 0b \n\t" "bge 0b \n\t"
"1: \n\t" // odd, last row "1: \n\t"
"subs %[kc2], %[kc2], #1 \n\t" "subs %[kc3], %[kc3], #1 \n\t"
"blt 2f \n\t" "blt 3f \n\t"
"2: \n\t"
"vld1.s8 {d0}, [%[a_ptr]], r0 \n\t" // A col0
"vld1.s8 {d1}, [%[a_ptr]], r0 \n\t" // A col1, q0
// used
"vld1.s8 {d2-d3}, [%[b_ptr]]! \n\t" // B row0, B
// row1, q1
// used
"vmov.s8 q2, #0 \n\t" // q2 used
"vdup.s8 d6, d0[0] \n\t"
"vdup.s8 d7, d1[0] \n\t" // q3 used
"vmlal.s8 q2, d2, d6 \n\t" // A col00 * B
// row0
"vmlal.s8 q2, d3, d7 \n\t" // A col10 * B
// row1, q3
// free
"vaddw.s16 q4, q4, d4 \n\t"
"vaddw.s16 q5, q5, d5 \n\t" // res row 0
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[1] \n\t"
"vdup.s8 d7, d1[1] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q6, q6, d4 \n\t"
"vaddw.s16 q7, q7, d5 \n\t" // res row 1
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[2] \n\t"
"vdup.s8 d7, d1[2] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q8, q8, d4 \n\t"
"vaddw.s16 q9, q9, d5 \n\t" // res row 2
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[3] \n\t"
"vdup.s8 d7, d1[3] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q10, q10, d4 \n\t"
"vaddw.s16 q11, q11, d5 \n\t" // res row 3
"vmov.s8 q2, #0. \n\t"
"vdup.s8 d6, d0[4] \n\t"
"vdup.s8 d7, d1[4] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q12, q12, d4 \n\t"
"vaddw.s16 q13, q13, d5 \n\t" // res row 4
"vmov.s8 q2, #0 \n\t"
"vdup.s8 d6, d0[5] \n\t"
"vdup.s8 d7, d1[5] \n\t"
"vmlal.s8 q2, d2, d6 \n\t"
"vmlal.s8 q2, d3, d7 \n\t"
"vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 5
"subs %[kc3], %[kc3], #1 \n\t"
"bge 2b \n\t"
"3: \n\t" // odd, last
// row
"subs %[kc4], %[kc4], #1 \n\t"
"blt 4f \n\t"
"vld1.s8 {d0}, [%[a_ptr]] \n\t" "vld1.s8 {d0}, [%[a_ptr]] \n\t"
"vld1.s8 {d1}, [%[b_ptr]] \n\t" "vld1.s8 {d1}, [%[b_ptr]] \n\t"
"vdup.s8 d2, d0[0] \n\t" "vdup.s8 d2, d0[0] \n\t"
...@@ -136,17 +358,16 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, ...@@ -136,17 +358,16 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
"vmull.s8 q2, d1, d2 \n\t" "vmull.s8 q2, d1, d2 \n\t"
"vaddw.s16 q14, q14, d4 \n\t" "vaddw.s16 q14, q14, d4 \n\t"
"vaddw.s16 q15, q15, d5 \n\t" // res row 4 "vaddw.s16 q15, q15, d5 \n\t" // res row 4
"2: \n\t" "4: \n\t"
"vst1.32 {q4, q5}, [%[c]], %[step] \n\t" "vst1.32 {q4, q5}, [%[c]], %[step] \n\t"
"vst1.32 {q6, q7}, [%[c]], %[step] \n\t" "vst1.32 {q6, q7}, [%[c]], %[step] \n\t"
"vst1.32 {q8, q9}, [%[c]], %[step] \n\t" "vst1.32 {q8, q9}, [%[c]], %[step] \n\t"
"vst1.32 {q10, q11}, [%[c]], %[step] \n\t" "vst1.32 {q10, q11}, [%[c]], %[step] \n\t"
"vst1.32 {q12, q13}, [%[c]], %[step] \n\t" "vst1.32 {q12, q13}, [%[c]], %[step] \n\t"
"vst1.32 {q14, q15}, [%[c]] \n\t" "vst1.32 {q14, q15}, [%[c]] \n\t"
: :
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[kc2] "r"(kc2), [step] "r"(step) [kc3] "r"(kc3), [kc4] "r"(kc4), [step] "r"(step)
: "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", : "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#endif #endif
......
...@@ -12,80 +12,80 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,80 +12,80 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "../test_helper.h"
#include "../test_include.h" #include "../test_include.h"
#include "operators/mul_op.h" #include "operators/mul_op.h"
int main() { #define a(i, j) a[(i)*lda + (j)]
paddle_mobile::Loader<paddle_mobile::CPU> loader; #define b(i, j) b[(i)*ldb + (j)]
auto program = loader.Load(g_resnet); #define c(i, j) c[(i)*ldc + (j)]
PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
"program file read fail"); namespace paddle_mobile {
using framework::AttributeMap;
Executor4Test<paddle_mobile::CPU, using framework::DDim;
paddle_mobile::operators::MulOp<paddle_mobile::CPU, float>> using framework::Scope;
executor(program, "mul"); using framework::make_ddim;
template <typename I, typename O>
// 1. input_tensors; int TestMulOP() {
vector<Tensor> input_tensors; int32_t m = 1024;
int32_t n = 1024;
Tensor input1; int32_t k = 1024;
auto input1_data = CreateInput<float>(&input1, {3, 2, 1, 1}, 0, 1); int32_t lda = k;
input_tensors.push_back(input1); int32_t ldb = n;
Tensor input2; int32_t ldc = n;
auto input2_data = CreateInput<float>(&input2, {2, 3}, 0, 1); DDim inputA_shape = make_ddim({m, k});
input_tensors.push_back(input2); DDim inputB_shape = make_ddim({k, n});
VariableNameMap inputs;
// 2. input_names VariableNameMap outputs;
vector<string> input_names({ auto scope = std::make_shared<Scope>();
"pool2d_0.tmp_0", inputs["X"] = std::vector<std::string>({"inputA"});
"fc_0.w_0", inputs["Y"] = std::vector<std::string>({"inputB"});
}); outputs["Out"] = std::vector<std::string>({"output"});
// 3. output_names auto inputA_var = scope.get()->Var("inputA");
vector<string> output_names({"fc_0.tmp_0"}); auto inputA = inputA_var->template GetMutable<framework::LoDTensor>();
SetupTensor<I>(inputA, inputA_shape, -127, 127);
// 4. out_dims; auto inputB_var = scope.get()->Var("inputB");
vector<DDim> out_ddims; auto inputB = inputB_var->template GetMutable<framework::LoDTensor>();
auto out_ddim = paddle_mobile::framework::make_ddim({3, 3}); SetupTensor<I>(inputB, inputB_shape, -127, 127);
out_ddims.push_back(out_ddim);
auto output_var = scope.get()->Var("output");
auto output = executor.Predict<LoDTensor>(input_tensors, input_names, AttributeMap attrs;
output_names, out_ddims); attrs["x_num_col_dims"].Set<int>(1);
attrs["y_num_col_dims"].Set<int>(1);
auto output0_data = output[0]->data<float>(); auto *op =
new operators::MulOp<CPU, float>("mul", inputs, outputs, attrs, scope);
auto dim_1 = input1.numel() / input1.dims()[0]; op->InferShape();
DLOG << " input1 : "; op->Run();
for (int i = 0; i < input1.dims()[0]; ++i) { auto output = output_var->template Get<framework::LoDTensor>();
for (int j = 0; j < dim_1; ++j) { const O *output_data = output->data<O>();
DLOGF("%f ", input1_data[i * dim_1 + j]); // compare
} O *c = static_cast<O *>(memory::Alloc(sizeof(O) * m * n));
DLOGF("\n"); I *a = inputA->data<I>();
I *b = inputB->data<I>();
for (int32_t i = 0; i < m; ++i) {
for (int32_t j = 0; j < n; ++j) {
O r = 0;
for (int32_t p = 0; p < k; p++) {
r += static_cast<O>(a(i, p)) * static_cast<O>(b(p, j));
} }
c(i, j) = r;
auto dim_2 = input2.numel() / input2.dims()[0];
DLOG << " input2 : ";
for (int i = 0; i < input2.dims()[0]; ++i) {
for (int j = 0; j < dim_2; ++j) {
DLOGF("%f ", input2_data[i * dim_2 + j]);
} }
DLOGF("\n");
} }
auto dim_output0 = output[0]->numel() / output[0]->dims()[0]; for (int32_t i = 0; i < m * n; ++i) {
DLOG << " output : "; PADDLE_MOBILE_ENFORCE(
for (int i = 0; i < output[0]->dims()[0]; ++i) { output_data[i] == c[i], "output[%d] = %d, output_cmp[%d] = %d", i,
for (int j = 0; j < dim_output0; ++j) { static_cast<int32_t>(output_data[i]), i, static_cast<int32_t>(c[i]));
DLOGF("%f ", output0_data[i * dim_2 + j]);
}
DLOGF("\n");
} }
DLOG << "Run MulOp successfully!";
delete op;
return 0;
}
} // namespace paddle_mobile
/// output (3,3) int main() {
DLOG << "output memory size : " << output[0]->memory_size(); paddle_mobile::TestMulOP<int8_t, int32_t>();
DLOG << "output numel : " << output[0]->numel(); paddle_mobile::TestMulOP<float, float>();
DLOG << input1_data[0] << " x " << input2_data[0] << " + " << input1_data[1]
<< " x " << input2_data[0 + 3] << " = " << output0_data[0];
return 0; return 0;
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册