提交 a47b90d1 编写于 作者: E eclipsess

code style

上级 7c59db45
...@@ -84,8 +84,8 @@ void StridedNumelCopyWithAxis(int64_t axis, T *dst, ...@@ -84,8 +84,8 @@ void StridedNumelCopyWithAxis(int64_t axis, T *dst,
} }
} }
template <> template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const { void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const {
auto inputs = param.Inputs(); auto inputs = param.Inputs();
auto *out = param.Out(); auto *out = param.Out();
int64_t axis = param.Axis(); int64_t axis = param.Axis();
...@@ -113,7 +113,7 @@ void StridedNumelCopyWithAxis(int64_t axis, T *dst, ...@@ -113,7 +113,7 @@ void StridedNumelCopyWithAxis(int64_t axis, T *dst,
ConcatFunctor<float> concat_functor; ConcatFunctor<float> concat_functor;
concat_functor(inputs_concat, static_cast<int>(axis), out); concat_functor(inputs_concat, static_cast<int>(axis), out);
} }
} }
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -14,18 +14,19 @@ limitations under the License. */ ...@@ -14,18 +14,19 @@ limitations under the License. */
#pragma once #pragma once
#include <operators/math/transform.h>
#include "operators/kernel/relu_kernel.h" #include "operators/kernel/relu_kernel.h"
#include <operators/math/transform.h>
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
template <typename T> template <typename T>
struct ReluFunctor { struct ReluFunctor {
inline T operator()(T in) const { return in > 0 ? in : 0; } inline T operator()(T in) const { return in > 0 ? in : 0; }
}; };
template <>
void ReluKernel<CPU, float>::Compute(const ReluParam &param) const { template <>
void ReluKernel<CPU, float>::Compute(const ReluParam &param) const {
const auto *input_x = param.InputX(); const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>(); auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out(); auto *out = param.Out();
...@@ -38,6 +39,6 @@ namespace operators { ...@@ -38,6 +39,6 @@ namespace operators {
// for (int i = 0; i < input_x->numel(); i++) { // for (int i = 0; i < input_x->numel(); i++) {
// out_ptr[i] = input_x_ptr[i] > 0 ? input_x_ptr[i] : 0; // out_ptr[i] = input_x_ptr[i] > 0 ? input_x_ptr[i] : 0;
// } // }
} }
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
/* /*
* Out = X ⊙ Y * Out = X ⊙ Y
...@@ -31,7 +31,7 @@ namespace paddle_mobile { ...@@ -31,7 +31,7 @@ namespace paddle_mobile {
* pre=2*3, n=4*5, post=1 * pre=2*3, n=4*5, post=1
* x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1) * x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
*/ */
inline void get_mid_dims(const framework::DDim &x_dims, inline void get_mid_dims(const framework::DDim &x_dims,
const framework::DDim &y_dims, const int axis, const framework::DDim &y_dims, const int axis,
int *pre, int *n, int *post) { int *pre, int *n, int *post) {
*pre = 1; *pre = 1;
...@@ -51,10 +51,10 @@ namespace paddle_mobile { ...@@ -51,10 +51,10 @@ namespace paddle_mobile {
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i]; (*post) *= x_dims[i];
} }
} }
/// remove dims tail 1. (4,20,1,1) -> (4,20) /// remove dims tail 1. (4,20,1,1) -> (4,20)
inline void trim_trailing_singular_dims(framework::DDim *dims) { inline void trim_trailing_singular_dims(framework::DDim *dims) {
// Remove trailing dimensions of size 1 for y // Remove trailing dimensions of size 1 for y
auto actual_dims_size = dims->size(); auto actual_dims_size = dims->size();
for (; actual_dims_size != 0; --actual_dims_size) { for (; actual_dims_size != 0; --actual_dims_size) {
...@@ -65,14 +65,14 @@ namespace paddle_mobile { ...@@ -65,14 +65,14 @@ namespace paddle_mobile {
actual_dims.resize(actual_dims_size); actual_dims.resize(actual_dims_size);
*dims = framework::make_ddim(actual_dims); *dims = framework::make_ddim(actual_dims);
} }
} }
/// (4,20,2)+(20,): (20,) just as (20,1), when move 2 strides in last /// (4,20,2)+(20,): (20,) just as (20,1), when move 2 strides in last
/// dimension /// dimension
/// in (4,20,2) is 2 , /// in (4,20,2) is 2 ,
/// (20,1) move 1 stride , to fill(add) 2 element with the same number. /// (20,1) move 1 stride , to fill(add) 2 element with the same number.
template <typename T> template <typename T>
class MidWiseTransformIterator { class MidWiseTransformIterator {
public: public:
MidWiseTransformIterator(const T *ptr, int n, int post) MidWiseTransformIterator(const T *ptr, int n, int post)
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
...@@ -113,10 +113,10 @@ namespace paddle_mobile { ...@@ -113,10 +113,10 @@ namespace paddle_mobile {
int64_t j_; int64_t j_;
int64_t n_; int64_t n_;
int64_t post_; int64_t post_;
}; };
template <typename Functor, typename T, typename OutType = T> template <typename Functor, typename T, typename OutType = T>
class TransformFunctor { class TransformFunctor {
public: public:
TransformFunctor(const framework::Tensor *x, const framework::Tensor *y, TransformFunctor(const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z, Functor func) framework::Tensor *z, Functor func)
...@@ -143,10 +143,10 @@ namespace paddle_mobile { ...@@ -143,10 +143,10 @@ namespace paddle_mobile {
OutType *z_; OutType *z_;
int64_t nx_; int64_t nx_;
Functor func_; Functor func_;
}; };
template <typename Functor, typename T, typename OutType = T> template <typename Functor, typename T, typename OutType = T>
void ElementwiseComputeEx(const framework::Tensor *x, void ElementwiseComputeEx(const framework::Tensor *x,
const framework::Tensor *y, int axis, Functor func, const framework::Tensor *y, int axis, Functor func,
framework::Tensor *z) { framework::Tensor *z) {
TransformFunctor<Functor, T, OutType> functor(x, y, z, func); TransformFunctor<Functor, T, OutType> functor(x, y, z, func);
...@@ -172,7 +172,7 @@ namespace paddle_mobile { ...@@ -172,7 +172,7 @@ namespace paddle_mobile {
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
functor.RunMidWise(n, pre, post); functor.RunMidWise(n, pre, post);
} }
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册