提交 33e1e2dd 编写于 作者: H hjchen2

Fix softmax

上级 9729edac
......@@ -350,7 +350,7 @@ PMStatus Executor<Device, T>::Predict() {
_tp[ops_list_[i]->Type()] += timeCost;
}
}
DLOG << "====================[ profile ]======================";
printf("====================[ profile ]======================\n");
typedef std::pair<std::string, uint64_t> prof_t;
std::vector<prof_t> _tv(_tp.begin(), _tp.end());
uint64_t _ptotal = 0;
......@@ -367,7 +367,7 @@ PMStatus Executor<Device, T>::Predict() {
static_cast<float>(p.second),
static_cast<float>(p.second) / _ptotal * 100.0);
}
DLOG << "====================[---------]======================";
printf("====================[---------]======================\n");
#endif
return PMSuccess;
}
......
......@@ -148,8 +148,8 @@ class Tensor : public TensorBase {
PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code()),
"Tensor holds the wrong type, it holds %s",
this->holder_->type().name());
"Tensor holds the wrong type, it holds %s, requested %s",
this->holder_->type().name(), typeid(T).name());
return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
......@@ -162,7 +162,7 @@ class Tensor : public TensorBase {
PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code()),
"Tensor holds the wrong type, it holds %s ,requested:%s",
"Tensor holds the wrong type, it holds %s, requested %s",
this->holder_->type().name(), typeid(T).name());
return reinterpret_cast<const T *>(
......
......@@ -25,12 +25,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class FillConstantOp : public framework::OperatorBase<DeviceType> {
public:
FillConstantOp(const string &type, const VariableNameMap &inputs,
FillConstantOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
......@@ -58,7 +57,7 @@ class FillConstantOp : public framework::OperatorBase<DeviceType> {
tensor->Resize(framework::make_ddim(param_.Shape()));
tensor->mutable_data(framework::ToTypeIndex(data_type));
math::set_constant(tensor, value);
math::SetConstant(tensor, value);
}
void Init() {}
......
......@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDADDPRELU_OP
#pragma once
#include <string>
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
......@@ -115,20 +116,7 @@ void ConvAddAddPReluCompute(const FusionConvAddAddPReluParam<CPU> &param) {
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
Tensor bias1_slice = bias1_batch.Slice(g * out_step, (g + 1) * out_step);
float *biase_data1 = bias1_slice.data<float>();
// int n = bias1_slice.dims()[0];
// int m = bias1_slice.dims()[1];
// for(int i=0;i<n*m;i++){
// if(biase_data1[i]!=0)
// DLOG<<biase_data1[i]<<",yangfei";
// }
// math::matmul<float>(filter_slice, false, col_matrix,
// false,
// static_cast<float>(1),
// &out_slice,
// static_cast<float>(1), true,
// biase_data);
math::matmulWithPRelu(filter_slice, false, col_matrix, false, &out_slice,
math::MatMulWithPRelu(filter_slice, false, col_matrix, false, &out_slice,
p, mode, biase_data, biase_data1);
}
}
......@@ -137,4 +125,4 @@ void ConvAddAddPReluCompute(const FusionConvAddAddPReluParam<CPU> &param) {
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // FUSION_CONVADDADDPRELU_OP
......@@ -107,7 +107,7 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float, float>(filter_slice, false, col_matrix, false,
math::MatMul<float, float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), false, biase_data);
}
......
......@@ -25,6 +25,7 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
void ConvAddBNReluBasic(const FusionConvAddBNReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
......@@ -105,12 +106,13 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam<CPU> &param) {
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
math::MatMulWithBn(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), true, &new_scale, &new_bias, g);
}
}
}
template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) {
Tensor Bias;
......@@ -126,9 +128,6 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) {
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
......
......@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDPRELU_OP
#pragma once
#include <string>
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
......@@ -30,8 +31,6 @@ void ConvAddPReluCompute(const FusionConvAddPReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
// DLOG<<"yangfei";
// DLOG<<bias.dims();
int axis = param.Axis();
Tensor *output = param.Output();
float *biase_data = bias.data<float>();
......@@ -112,13 +111,7 @@ void ConvAddPReluCompute(const FusionConvAddPReluParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
// math::matmul<float>(filter_slice, false, col_matrix,
// false,
// static_cast<float>(1),
// &out_slice,
// static_cast<float>(1), true,
// biase_data);
math::matmulWithPRelu(filter_slice, false, col_matrix, false, &out_slice,
math::MatMulWithPRelu(filter_slice, false, col_matrix, false, &out_slice,
p, mode, biase_data, nullptr);
}
}
......@@ -127,4 +120,4 @@ void ConvAddPReluCompute(const FusionConvAddPReluParam<CPU> &param) {
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // FUSION_CONVADDPRELU_OP
......@@ -112,7 +112,7 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Itype, Otype>(filter_slice, false, col_matrix, false, alpha,
math::MatMul<Itype, Otype>(filter_slice, false, col_matrix, false, alpha,
&out_slice, beta, true, bias_data);
}
}
......
......@@ -106,7 +106,7 @@ inline void GemmConv(const ConvParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Itype, Otype>(filter_slice, false, col_matrix, false,
math::MatMul<Itype, Otype>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), false,
static_cast<Otype *>(nullptr));
......
......@@ -108,10 +108,10 @@ void ConvBNAddReluBasic(const FusionConvBNAddReluParam<CPU> &param) {
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
Tensor bias_data = bias_batch.Slice(g * out_step, (g + 1) * out_step);
math::matmulWithBn<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), true, &new_scale,
&new_bias, g, bias_data.data<float>());
math::MatMulWithBn(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), true, &new_scale, &new_bias, g,
bias_data.data<float>());
}
}
}
......
......@@ -107,9 +107,9 @@ void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) {
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
math::MatMulWithBn(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), true, &new_scale, &new_bias, g);
}
}
}
......
......@@ -93,7 +93,7 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
Tensor out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
math::matmul<P, P>(filter_slice, true, in_slice, false,
math::MatMul<P, P>(filter_slice, true, in_slice, false,
static_cast<P>(1.0), &col_matrix, static_cast<P>(0.0));
if (data_dim == 2U) {
col2im(col, dilations, strides,
......
......@@ -106,9 +106,9 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
math::MatMulWithBn(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), true, &new_scale, &new_bias, g);
}
}
}
......
......@@ -57,7 +57,7 @@ void FusionFcCompute(const FusionFcParam<CPU> &param) {
for (int i = 0; i < out_dim[0]; i++) {
memory::Copy(out_data + i * classes, input_z_data, sizeof(Otype) * classes);
}
math::matmul<Itype, Otype>(x_matrix, false, y_matrix, false,
math::MatMul<Itype, Otype>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out, static_cast<float>(1),
false);
}
......
......@@ -25,18 +25,16 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename DeviceType, typename T>
template <typename Device, typename T>
inline void ReorderInitState(const framework::Tensor& src,
std::vector<size_t> index_lod,
framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<DeviceType, T> row_shuffle;
math::CopyMatrixRowsFunctor<Device, T> row_shuffle;
dst->mutable_data<T>(src.dims());
row_shuffle(src, index_lod, dst, indexed_src);
}
template <typename P>
template <typename T>
void GruCompute(const GruParam<CPU>& param) {
auto* input = param.InputInput();
auto* h0 = param.InputH0();
......@@ -57,8 +55,6 @@ void GruCompute(const GruParam<CPU>& param) {
bool is_reverse = param.IsReverse();
math::LoDTensor2BatchFunctor<CPU, float> to_batch;
to_batch(*input, batch_gate, true, is_reverse);
// math::ClearTensor<CPU, float> clearTensor;
// clearTensor(batch_gate);
if (bias) {
math::RowwiseAdd<CPU, float> add_bias;
add_bias(*batch_gate, *bias, batch_gate);
......@@ -68,7 +64,7 @@ void GruCompute(const GruParam<CPU>& param) {
gru_value.gate_weight = const_cast<float*>(weight_data);
gru_value.state_weight =
const_cast<float*>(weight_data + 2 * frame_size * frame_size);
Tensor ordered_h0;
framework::Tensor ordered_h0;
std::vector<size_t> order(batch_gate->lod()[2]);
if (h0) {
// Since the batch computing for GRU reorders the input sequences
......@@ -87,9 +83,10 @@ void GruCompute(const GruParam<CPU>& param) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend); // BUG
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
framework::Tensor gate_t = batch_gate->Slice(bstart, bend);
framework::Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend);
framework::Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<float>();
gru_value.gate_value = gate_t.data<float>();
gru_value.reset_output_value = reset_hidden_prev_t.data<float>();
......@@ -105,7 +102,6 @@ void GruCompute(const GruParam<CPU>& param) {
}
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // GRU_OP
......@@ -19,40 +19,6 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
// 1、如果x,y维度都是2维,
// x = [[1,2], y = [[5,6],
// [3,4]] [7,8]]
// 运算结果为正常矩阵相乘。结果 out =
// [[1*5+2*7,1*6+2*8],[3*5+4*7, 3*6+4*8]]
//
// 2、如果x的维度大于2或者y的维度大于2,x的维度(2,3,4) ,y的维度(4,1,2)
// x = [[[1,2,3,4],
// [2,3,4,5],
// [3,4,5,6]],
// [[1,2,3,4],
// [2,3,4,5],
// [3,4,5,6]]]
// y = [[[1,2]],
// [[3,4]],
// [[5,6]],
// [[7,8]]]
// 需要借助x_num_col_dims和y_num_col_dims将x和y的维度转换为2维
// 从模型中读到参数,x_num_col_dims = 2,y_num_col_dims = 1,左开右闭
// (1) 将x = (2,3,4)的index [0,x_num_col_dims)部分2,3相乘,得到6,
// [x_num_col_dims,xdim.size())部分4相乘,得到4,
// 将Tensor x的dims重写成(6,4)
// (2) 将y = (4,1,2)的index [0,y_num_col_dims)部分4相乘,得到4,
// [y_num_col_dims,ydim.size())部分1,2相乘,得到2,
// 将Tensor y的dims重写成(4,2)
// 并不影响x,y在内存中的分布。
// x = [[1,2,3,4], y = [[1,2],
// [2,3,4,5], [3,4],
// [3,4,5,6], 矩阵乘法 [5,6],
// [1,2,3,4], [7,8]]
// [2,3,4,5],
// [3,4,5,6]]
// 结果x(6行4列)乘y(4行2列),按1中矩阵相乘,结果out(6行2列)
template <typename P>
void MulCompute(const MulParam<CPU> &param) {
const Tensor *input_x = param.InputX();
......@@ -73,12 +39,12 @@ void MulCompute(const MulParam<CPU> &param) {
}
if (param.InputX()->type() == typeid(int8_t)) {
out->mutable_data<int32_t>();
math::matmul<int8_t, int32_t>(x_matrix, false, y_matrix, false,
math::MatMul<int8_t, int32_t>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out,
static_cast<float>(0));
} else {
out->mutable_data<float>();
math::matmul<float, float>(x_matrix, false, y_matrix, false,
math::MatMul<float, float>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out,
static_cast<float>(0));
}
......
......@@ -94,27 +94,19 @@ void FusionFcCompute(const FusionFcParam<GPU_CL> &param, cl_context context,
memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
}
// for (int i = 0; i < out->numel(); i++) {
// DLOG << out_data[i];
// }
// bias_data的维度和out的维度一致
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
math::MatMul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(1), false);
out_image->InitEmptyImage(context, commandQueue, out->dims());
framework::TensorToCLImage(out, out_image, context, commandQueue, kernel1);
DLOG << *out;
delete (input_x);
delete (input_y);
delete (input_z);
delete (out);
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
// if (out_dim.size() != 2) {
// out->Resize(out_dim);
// }
}
template <>
void FusionFcKernel<GPU_CL, float>::Compute(
const FusionFcParam<GPU_CL> &param) {
......
......@@ -61,7 +61,7 @@ void FusionFcKernel<GPU_MALI, float>::Compute(
for (int i = 0; i < out->numel(); i++) {
DLOG << out_data[i];
}
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
math::MatMul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(1));
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
// if (out_dim.size() != 2) {
......
......@@ -44,7 +44,7 @@ void MulKernel<GPU_MALI, float>::Compute(const MulParam<GPU_MALI> &param) {
if (out_dim.size() != 2) {
out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
math::MatMul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(0));
if (out_dim.size() != 2) {
out->Resize(out_dim);
......
......@@ -38,7 +38,11 @@ limitations under the License. */
*
* (this is the zlib license)
*/
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#pragma once
#include <arm_neon.h>
#define c_inv_mant_mask ~0x7f800000u
......@@ -316,11 +320,11 @@ static inline float32x4_t cos_ps(float32x4_t x) {
static inline float32x4_t div_ps(float32x4_t a, float32x4_t b) {
float32x4_t reciprocal = vrecpeq_f32(b);
reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal);
// reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal);
return vmulq_f32(a, reciprocal);
}
static inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) {
// pow(x, m) = exp(m * log(x))
return exp_ps(vmulq_f32(b, log_ps(a)));
}
#endif // __ARM_NEON__
......@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/math_function.h"
#include <cstring>
#include <string>
#include "common/enforce.h"
#include "framework/data_type.h"
#include "framework/tensor.h"
#include "operators/math/gemm.h"
......@@ -35,13 +35,13 @@ struct TensorSetConstant {
float value_;
};
void set_constant(framework::Tensor *tensor, float value) {
void SetConstant(framework::Tensor *tensor, float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
TensorSetConstant(tensor, value));
}
template <>
void matmul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
void MatMul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out,
float beta, bool relu, float *bias) {
......@@ -50,7 +50,7 @@ void matmul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
auto dim_out = matrix_out->dims();
PADDLE_MOBILE_ENFORCE(
dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
"The input and output of MatMul be matrix");
int M = dim_out[0];
int N = dim_out[1];
......@@ -72,7 +72,6 @@ void matmul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
}
#ifdef _OPENMP
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu, bias);
#else
......@@ -92,19 +91,18 @@ void matmul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
}
}
template <>
void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out, float beta,
bool relu, framework::Tensor *new_scale,
framework::Tensor *new_bias, int group, float *bias) {
void MatMulWithBn(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu,
framework::Tensor *new_scale, framework::Tensor *new_bias,
int group, float *bias) {
Gemm gemm;
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_MOBILE_ENFORCE(
dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
"The input and output of MatMul be matrix");
int M = dim_out[0];
int N = dim_out[1];
......@@ -122,7 +120,7 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
new_bias->data<float>() + group, bias);
#endif
}
void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
void MatMulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
framework::Tensor *matrix_out, float *p, std::string mode,
float *bias, float *bias1) {
......@@ -132,7 +130,7 @@ void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
auto dim_out = matrix_out->dims();
PADDLE_MOBILE_ENFORCE(
dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
"The input and output of MatMul be matrix");
int M = dim_out[0];
int N = dim_out[1];
......@@ -146,7 +144,6 @@ void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
gemm.SgemmWithPRelu(M, N, K, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, matrix_out->data<float>(), N,
p, mode, bias, bias1);
#endif
}
......
......@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once
#include <cmath>
#include <string>
#include "framework/tensor.h"
......@@ -22,37 +21,37 @@ namespace paddle_mobile {
namespace operators {
namespace math {
void set_constant(framework::Tensor *tensor, float value);
void SetConstant(framework::Tensor *tensor, float value);
template <typename Itype, typename Otype>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
void MatMul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu = false,
Otype *bias = nullptr);
template <typename Itype, typename Otype>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
void MatMul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu, Otype *bias,
bool addOnRow);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
void MatMulWithBn(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu,
framework::Tensor *new_scale, framework::Tensor *new_bias,
int group, T *bias = nullptr);
int group, float *bias = nullptr);
void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
void MatMulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
framework::Tensor *matrix_out, float *p, std::string mode,
float *bias, float *bias1);
template <typename DeviceType, typename T>
template <typename Device, typename T>
struct ClearTensor {
void operator()(framework::Tensor *tensor);
};
template <typename DeviceType, typename T>
template <typename Device, typename T>
struct RowwiseAdd {
void operator()(const framework::Tensor &input, const framework::Tensor &vec,
framework::Tensor *output);
......
......@@ -22,7 +22,7 @@ namespace operators {
namespace math {
template <>
void matmul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
void MatMul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out,
float beta, bool relu, int32_t *bias,
......@@ -32,7 +32,7 @@ void matmul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
auto dim_out = matrix_out->dims();
PADDLE_MOBILE_ENFORCE(
dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
"The input and output of MatMul be matrix");
int32_t M = dim_out[0];
int32_t N = dim_out[1];
......@@ -96,11 +96,11 @@ void matmul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
}
template <>
void matmul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
void MatMul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out,
float beta, bool relu, int32_t *bias) {
matmul<int8_t, int32_t>(matrix_a, trans_a, matrix_b, trans_b, alpha,
MatMul<int8_t, int32_t>(matrix_a, trans_a, matrix_b, trans_b, alpha,
matrix_out, beta, relu, bias, false);
}
......
......@@ -15,154 +15,131 @@ limitations under the License. */
#ifdef SOFTMAX_OP
#include "operators/math/softmax.h"
#include "common/types.h"
#ifdef __ARM_NEON
#include <math.h>
#include <algorithm>
#include <limits>
#include "common/types.h"
#include "operators/math/math_func_neon.h"
#endif
namespace paddle_mobile {
namespace operators {
namespace math {
using framework::DDim;
using framework::Tensor;
template <typename T>
class SoftmaxFuntor<CPU, T> {
#ifdef __ARM_NEON
void sum(float *input, float *sumptr, int inner_size, int outter_size) {
float32x4_t acc = vdupq_n_f32(0);
float sum_ = 0;
for (int i = 0; i < outter_size; ++i) {
float *input_outer_ptr = input + i * inner_size;
int nn = inner_size >> 2;
int left = inner_size - (nn << 2);
for (; nn > 0; nn--) {
float32x4_t vec_input = vld1q_f32(input_outer_ptr);
acc = vaddq_f32(acc, vec_input);
input_outer_ptr += 4;
}
float32x2_t vsum_ = vadd_f32(vget_high_f32(acc), vget_low_f32(acc));
sum_ = vget_lane_f32(vsum_, 0) + vget_lane_f32(vsum_, 1);
for (; left > 0; left--) {
sum_ += *input_outer_ptr;
input_outer_ptr++;
}
}
for (int j = 0; j < inner_size * outter_size; ++j) {
sumptr[j] = sum_;
}
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#ifndef __aarch64__
inline float32_t vmaxvq_f32(const float32x4_t &r) {
float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r));
return vget_lane_f32(vpmax_f32(v, v), 0);
}
inline float32_t vaddvq_f32(const float32x4_t &r) {
float32x2_t v = vadd_f32(vget_high_f32(r), vget_low_f32(r));
return vget_lane_f32(vpadd_f32(v, v), 0);
}
#endif // __aarch64__
#endif // __ARM_NEON__
float find_max(const float *input, const int num_classes) {
int remain = num_classes;
float max = -std::numeric_limits<float>::max();
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int loop = num_classes >> 3;
remain = num_classes & 0x7;
float32x4_t __max = vdupq_n_f32(max);
for (int i = 0; i < loop; ++i, input += 8) {
float32x4_t x0 = vld1q_f32(input);
float32x4_t x1 = vld1q_f32(input + 4);
__max = vmaxq_f32(x0, __max);
__max = vmaxq_f32(x1, __max);
}
max = vmaxvq_f32(__max);
#endif
for (int i = 0; i < remain; ++i) {
max = std::max(max, input[i]);
}
return max;
}
void SoftmaxCacl(const Tensor *X, Tensor *Y) {
const float *input = X->data<float>();
const DDim &dDim = X->dims();
int axis_index = 1;
if (dDim.size() < 4) {
axis_index = 0;
}
DDim outer_ddim =
paddle_mobile::framework::slice_ddim(dDim, 0, axis_index + 1);
DDim inner_ddim =
paddle_mobile::framework::slice_ddim(dDim, axis_index + 1, dDim.size());
int out_size = paddle_mobile::framework::product(outer_ddim);
int inner_size = paddle_mobile::framework::product(inner_ddim);
auto *max_ptr = new float[inner_size * out_size];
// max
for (int j = 0; j < out_size; ++j) {
const float *input_outer_ptr = input + j * inner_size;
float *max_outer_ptr = max_ptr + j * inner_size;
float max_ = 0;
for (int i = 0; i < inner_size; ++i) {
const float *input_inner_ptr = input_outer_ptr + i;
max_ = std::max(max_, input_inner_ptr[0]);
}
for (int k = 0; k < inner_size; ++k) {
max_outer_ptr[k] = max_;
template <>
void SoftmaxFuntor<CPU, float>::operator()(const framework::Tensor *X,
framework::Tensor *Y) {
const framework::DDim &dims = X->dims();
int batch_size = dims[0];
int num_classes = dims[dims.size() - 1];
int channels = X->numel() / batch_size / num_classes;
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < X->dims()[0]; ++batch) {
for (int channel = 0; channel < channels; ++channel) {
size_t offset = (batch * channels + channel) * num_classes;
const float *input = x + offset;
float *output = y + offset;
// find max
float max = find_max(input, num_classes);
// exp(x - max)
int remain = num_classes;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int loop = num_classes >> 3;
remain = num_classes & 0x7;
float32x4_t __max = vdupq_n_f32(max);
for (int i = 0; i < loop; ++i, input += 8, output += 8) {
float32x4_t x0 = vld1q_f32(input);
float32x4_t x1 = vld1q_f32(input + 4);
x0 = vsubq_f32(x0, __max);
x1 = vsubq_f32(x1, __max);
x0 = exp_ps(x0);
x1 = exp_ps(x1);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
}
// exp(value - max)
float *exp_sub_max = new float[inner_size * out_size];
float *exp_sub_max_ptr = &exp_sub_max[0];
for (int l = 0; l < out_size; ++l) {
const float *input_outer_ptr = input + l * inner_size;
float *max_outer_ptr = max_ptr + l * inner_size;
int nn = inner_size >> 2;
int left = inner_size - (nn << 2);
for (; nn > 0; nn--) {
float32x4_t vec_input = vld1q_f32(input_outer_ptr);
float32x4_t vec_max = vld1q_f32(max_outer_ptr);
float32x4_t vec_sub = vsubq_f32(vec_input, vec_max);
float32x4_t vec_exp = exp_ps(vec_sub);
vst1q_f32(exp_sub_max_ptr, vec_exp);
input_outer_ptr += 4;
max_outer_ptr += 4;
exp_sub_max_ptr += 4;
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
output[i] = std::expf(input[i] - max);
}
for (; left > 0; left--) {
*exp_sub_max_ptr = expf(*input_outer_ptr - *max_outer_ptr);
input_outer_ptr++;
max_outer_ptr++;
exp_sub_max_ptr++;
// sum(exp(x - max))
float sum = 0.f;
output = y + offset;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __sum = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
__sum = vaddq_f32(x0, __sum);
__sum = vaddq_f32(x1, __sum);
}
}
float *sumptr = new float[inner_size * out_size];
// sum exp
sum(exp_sub_max, sumptr, inner_size, out_size);
// div
auto *out_ptr = Y->mutable_data<float>();
for (int l = 0; l < out_size; ++l) {
const float *input_outer_ptr = exp_sub_max + l * inner_size;
float *output_outer_ptr = out_ptr + l * inner_size;
float *sum_outer_ptr = sumptr + l * inner_size;
int nn = inner_size >> 2;
int left = inner_size - (nn << 2);
for (; nn > 0; nn--) {
float32x4_t vec_input = vld1q_f32(input_outer_ptr);
float32x4_t vec_sum = vld1q_f32(sum_outer_ptr);
float32x4_t vec_div = div_ps(vec_input, vec_sum);
vst1q_f32(output_outer_ptr, vec_div);
input_outer_ptr += 4;
output_outer_ptr += 4;
sum_outer_ptr += 4;
sum += vaddvq_f32(__sum);
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
sum += output[i];
}
for (; left > 0; left--) {
*output_outer_ptr = (*input_outer_ptr) / (*sum_outer_ptr);
input_outer_ptr++;
output_outer_ptr++;
sum_outer_ptr++;
// exp(x - max) / sum
float inv_sum = 1.f / sum;
output = y + offset;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __inv_sum = vdupq_n_f32(inv_sum);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
x0 = vmulq_f32(x0, __inv_sum);
x1 = vmulq_f32(x1, __inv_sum);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
}
}
#else
#endif // ARM_NEON
public:
void operator()(const framework::Tensor *X, framework::Tensor *Y) {
const DDim dDim = X->dims();
int dim1 = dDim[dDim.size() - 1];
int dim0 = X->numel() / dim1 / dDim[0];
framework::DDim matrix_shape = {dim0, dim1};
for (int i = 0; i < dDim[0]; ++i) {
framework::Tensor sub_X = X->Slice(i, i + 1);
framework::Tensor sub_Y = Y->Slice(i, i + 1);
sub_X.Resize(matrix_shape);
sub_Y.Resize(matrix_shape);
for (int j = 0; j < dim0; j++) {
framework::Tensor sub_x = sub_X.Slice(j, j + 1);
framework::Tensor sub_y = sub_Y.Slice(j, j + 1);
#ifdef __ARM_NEON
SoftmaxCacl(&sub_x, &sub_y);
#endif
for (int i = 0; i < remain; ++i) {
output[i] *= inv_sum;
}
}
}
};
template class SoftmaxFuntor<CPU, float>;
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // SOFTMAX_OP
......@@ -13,17 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SOFTMAX_OP
#pragma once
#include "framework/tensor.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <typename DeviceType, typename T>
template <typename Device, typename T>
class SoftmaxFuntor {
public:
void operator()(const framework::Tensor *X, framework::Tensor *Y);
};
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......
......@@ -261,20 +261,17 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-inference-api framework/test_inference_api.cpp)
target_link_libraries(test-inference-api paddle-mobile)
# gen test log
# gen test
ADD_EXECUTABLE(test-optimize framework/test_optimize.cpp)
target_link_libraries(test-optimize paddle-mobile)
#gen test
ADD_EXECUTABLE(test-pool-op operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-pool-op paddle-mobile)
#gen test
ADD_EXECUTABLE(test-softmax operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-softmax paddle-mobile)
ADD_EXECUTABLE(test-softmax-op operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-softmax-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-gemm-accuracy common/test_gemm_accuracy.cpp)
......
......@@ -73,14 +73,14 @@ int main() {
// float
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float, float>(
paddle_mobile::operators::math::MatMul<float, float>(
aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
false, nullptr);
}
auto time_start0 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float, float>(
paddle_mobile::operators::math::MatMul<float, float>(
aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
false, nullptr);
}
......@@ -91,14 +91,14 @@ int main() {
// int8_t without bias
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0));
}
auto time_start1 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0));
}
......@@ -109,13 +109,13 @@ int main() {
// int8_t with bias, column element wise add
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_col, false);
}
auto time_start2 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_col, false);
}
......@@ -126,13 +126,13 @@ int main() {
// int8_t with bias, row element wise add
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_row, true);
}
auto time_start3 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_row, true);
}
......@@ -143,13 +143,13 @@ int main() {
// int8_t with bias&relu
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), true, bias_data_col, false);
}
auto time_start4 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), true, bias_data_col, false);
}
......
......@@ -28,6 +28,7 @@ void load_images(const char *image_dir, const char *images_list,
image_shapes->push_back(std::make_pair(height, width));
image_names->push_back(filename);
}
if_list.close();
}
int main(int argc, char **argv) {
......@@ -53,7 +54,7 @@ int main(int argc, char **argv) {
for (int i = 0; i < image_names.size(); i++) {
std::string file_name = image_names[i];
std::vector<float> input;
std::vector<float> input_vec;
std::vector<int64_t> dims{1, 1, 48, 512};
dims[2] = image_shapes[i].first;
dims[3] = image_shapes[i].second;
......@@ -62,14 +63,22 @@ int main(int argc, char **argv) {
std::cerr << "img_path: " << img_path << std::endl;
std::cerr << "shape = [" << dims[0] << ", " << dims[1] << ", " << dims[2]
<< ", " << dims[3] << "]" << std::endl;
GetInput<float>(img_path, &input, dims);
GetInput<float>(img_path, &input_vec, dims);
framework::Tensor input(input_vec, framework::make_ddim(dims));
// predict
auto output = paddle_mobile.Predict(input, dims);
paddle_mobile.Predict(input);
auto output_topk = paddle_mobile.Fetch("top_k_1.tmp_0");
auto output_indices = paddle_mobile.Fetch("cast_68.tmp_0");
// print result
std::cerr << file_name << std::endl;
std::cerr << output[0];
for (int j = 1; j < output.size(); ++j) {
std::cerr << " " << output[j];
std::cerr << output_topk->data<float>()[0];
for (int j = 1; j < output_topk->numel(); ++j) {
std::cerr << " " << output_topk->data<float>()[j];
}
std::cerr << std::endl;
std::cerr << output_indices->data<float>()[0];
for (int j = 1; j < output_indices->numel(); ++j) {
std::cerr << " " << output_indices->data<float>()[j];
}
std::cerr << std::endl;
}
......
......@@ -12,29 +12,88 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include <limits>
#include "../test_include.h"
#include "operators/softmax_op.h"
int main() {
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_mobilenet));
if (program.originProgram == nullptr) {
DLOG << "program read file";
namespace paddle_mobile {
void Softmax(const framework::Tensor *X, framework::Tensor *Y) {
const framework::DDim &dims = X->dims();
int batch_size = dims[0];
int num_classes = dims[dims.size() - 1];
int channels = X->numel() / batch_size / num_classes;
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channels; ++c) {
size_t offset = (batch * channels + c) * num_classes;
const float *input = x + offset;
float *output = y + offset;
float max = -std::numeric_limits<float>::max();
for (int j = 0; j < num_classes; ++j) {
max = (input[j] > max) ? input[j] : max;
}
float sum = 0.f;
for (int j = 0; j < num_classes; ++j) {
float tmp = std::expf(input[j] - max);
sum += tmp;
output[j] = tmp;
}
for (int j = 0; j < num_classes; ++j) {
output[j] /= sum;
}
}
}
Executor4Test<paddle_mobile::CPU,
paddle_mobile::operators::SoftmaxOp<paddle_mobile::CPU, float>>
executor(program, "softmax");
paddle_mobile::framework::Tensor input;
SetupTensor<float>(&input, {1, 1000}, static_cast<float>(0),
static_cast<float>(1));
auto out_ddim = paddle_mobile::framework::make_ddim({1, 1000});
auto output =
executor.Predict(input, "reshape_0.tmp_0", "softmax_0.tmp_0", out_ddim);
auto *output_ptr = output->data<float>();
for (int j = 0; j < output->numel(); ++j) {
DLOG << " value of output: " << output_ptr[j];
}
int TestSoftmaxOp(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
auto output = output_var->template Get<framework::LoDTensor>();
framework::AttributeMap attrs;
auto *op = new operators::SoftmaxOp<CPU, float>("softmax", inputs, outputs,
attrs, scope);
op->InferShape();
op->Init();
op->Run();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Softmax(input, &output_cmp);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
TestSoftmaxOp({128, 1000});
TestSoftmaxOp({128, 10, 1000});
return 0;
}
......@@ -5,7 +5,7 @@ TOTAL_ERRORS=0
# The trick to remove deleted files: https://stackoverflow.com/a/2413151
for file in $(git diff --cached --name-status | awk '$1 != "D" {print $2}' | \
grep -v ".pb.cpp" | grep -v ".pb.h" | grep -v ".pb-c.h" | grep -v ".pb-c.c" | \
grep -v "protobuf-c.h" | grep -v "protobuf-c.c" | grep -v "paddle_mobile_jni.cpp"); do
grep -v "protobuf-c.h" | grep -v "protobuf-c.c"); do
cpplint $file;
TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
done
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册