未验证 提交 bfbca2c6 编写于 作者: E eclipsycn 提交者: GitHub

Merge pull request #299 from Eclipsess/develop

fix #298 optimize some ops
...@@ -71,8 +71,9 @@ void BatchNormKernel<CPU, float>::Compute(const BatchNormParam &param) const { ...@@ -71,8 +71,9 @@ void BatchNormKernel<CPU, float>::Compute(const BatchNormParam &param) const {
{ {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int h = 0; h < H; h++) { for (int h = 0; h < H; h++) {
int tmp_index = n * stride0 + i * stride1 + h * stride2;
for (int w = 0; w < W; w++) { for (int w = 0; w < W; w++) {
int index = n * stride0 + i * stride1 + h * stride2 + w; int index = tmp_index + w;
out_ptr[index] = out_ptr[index] =
input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i]; input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
} }
......
...@@ -51,38 +51,6 @@ class ConcatFunctor { ...@@ -51,38 +51,6 @@ class ConcatFunctor {
} }
} }
}; };
template <typename T>
void StridedNumelCopyWithAxis(int64_t axis, T *dst,
const framework::DDim &dst_stride_numel,
const T *src,
const framework::DDim &src_stride_numel,
int64_t size) {
int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
int64_t src_after = src_stride_numel[axis];
int64_t dst_after = dst_stride_numel[axis];
/// "src and dst tensor should have the same dims size."
assert(src_stride_numel.size() == dst_stride_numel.size());
for (int64_t i = 0; i < axis; ++i) {
if (i < axis) {
/// src and dst should have the same elements
/// except the specified axis.
assert(src_stride_numel[i] / src_stride_numel[axis] ==
dst_stride_numel[i] / dst_stride_numel[axis]);
} else if (i == axis) {
continue;
} else {
/// "src and dst should have the same elements "
/// "except the specified axis."
assert(src_stride_numel[i] == dst_stride_numel[i]);
}
}
for (int64_t i = 0; i < before; ++i) {
memory::Copy(dst + i * dst_after, src + i * src_after, sizeof(T) * size);
}
}
template <> template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const { void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const {
...@@ -97,10 +65,13 @@ void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const { ...@@ -97,10 +65,13 @@ void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const {
for (auto *in : inputs) { for (auto *in : inputs) {
auto in_stride = framework::stride_numel(in->dims()); auto in_stride = framework::stride_numel(in->dims());
auto out_stride = framework::stride_numel(out->dims()); auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<float>(axis, out->data<float>() + output_offset, auto dst = out->data<float>() + output_offset;
out_stride, in->data<float>(), in_stride, auto src = in->data<float>();
in_stride[axis]); PADDLE_MOBILE_ENFORCE(
output_offset += in_stride[axis]; in_stride.size() == out_stride.size(),
"src and dst tensor should have the same dims size.");
memory::Copy(dst, src, sizeof(float) * in_stride[0]);
output_offset += in_stride[0];
} }
} else { } else {
std::vector<framework::Tensor> inputs_concat(inputs.size()); std::vector<framework::Tensor> inputs_concat(inputs.size());
......
...@@ -15,19 +15,30 @@ limitations under the License. */ ...@@ -15,19 +15,30 @@ limitations under the License. */
#pragma once #pragma once
#include "operators/kernel/relu_kernel.h" #include "operators/kernel/relu_kernel.h"
#include <operators/math/transform.h>
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
template <typename T>
struct ReluFunctor {
inline T operator()(T in) const { return in > 0 ? in : 0; }
};
template <> template <>
void ReluKernel<CPU, float>::Compute(const ReluParam &param) const { void ReluKernel<CPU, float>::Compute(const ReluParam &param) const {
const auto *input_x = param.InputX(); const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>(); auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out(); auto *out = param.Out();
auto *out_ptr = out->mutable_data<float>(); auto *out_ptr = out->mutable_data<float>();
for (int i = 0; i < input_x->numel(); i++) {
out_ptr[i] = input_x_ptr[i] > 0 ? input_x_ptr[i] : 0; ReluFunctor<float> func_;
} math::Transform trans;
trans(input_x_ptr, input_x_ptr + input_x->numel(), out_ptr, func_);
// for (int i = 0; i < input_x->numel(); i++) {
// out_ptr[i] = input_x_ptr[i] > 0 ? input_x_ptr[i] : 0;
// }
} }
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -42,12 +42,13 @@ struct LRNFunctor { ...@@ -42,12 +42,13 @@ struct LRNFunctor {
for (int index = start; index < end; index++) { for (int index = start; index < end; index++) {
int channel = b + index; int channel = b + index;
if (channel >= 0 && channel < C) { if (channel >= 0 && channel < C) {
int tmp_u = a * stride0 + b * stride1;
int tmp_i = a * stride0 + channel * stride1;
for (int c = 0; c < H; c++) { for (int c = 0; c < H; c++) {
for (int d = 0; d < W; d++) { for (int d = 0; d < W; d++) {
int u = a * stride0 + b * stride1 + c * stride2 + d; int tmp = c * stride2 + d;
int u = tmp_u + tmp;
int i = a * stride0 + channel * stride1 + c * stride2 + d; int i = tmp_i + tmp;
sqr_buffer_ptr[u] += alpha * input_ptr[i] * input_ptr[i]; sqr_buffer_ptr[u] += alpha * input_ptr[i] * input_ptr[i];
} }
} }
......
...@@ -67,35 +67,6 @@ inline void trim_trailing_singular_dims(framework::DDim *dims) { ...@@ -67,35 +67,6 @@ inline void trim_trailing_singular_dims(framework::DDim *dims) {
} }
} }
template <typename T>
class RowwiseTransformIterator {
public:
RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
RowwiseTransformIterator<T> &operator++() {
++i_;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
}
return *this;
}
bool operator==(const RowwiseTransformIterator<T> &rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(const RowwiseTransformIterator<T> &rhs) const {
return (ptr_ + i_) != &(*rhs);
}
const T &operator*() { return ptr_[i_]; }
private:
const T *ptr_;
int i_;
int64_t n_;
};
/// (4,20,2)+(20,): (20,) just as (20,1), when move 2 strides in last /// (4,20,2)+(20,): (20,) just as (20,1), when move 2 strides in last
/// dimension /// dimension
/// in (4,20,2) is 2 , /// in (4,20,2) is 2 ,
...@@ -107,15 +78,23 @@ class MidWiseTransformIterator { ...@@ -107,15 +78,23 @@ class MidWiseTransformIterator {
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
MidWiseTransformIterator<T> &operator++() { MidWiseTransformIterator<T> &operator++() {
++j_; if (post_ != 1) {
if (UNLIKELY(j_ == post_)) { ++j_;
if (UNLIKELY(j_ == post_)) {
++i_;
j_ = 0;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
}
}
return *this;
} else {
++i_; ++i_;
j_ = 0;
if (UNLIKELY(i_ == n_)) { if (UNLIKELY(i_ == n_)) {
i_ = 0; i_ = 0;
} }
return *this;
} }
return *this;
} }
bool operator==(const MidWiseTransformIterator<T> &rhs) const { bool operator==(const MidWiseTransformIterator<T> &rhs) const {
...@@ -153,11 +132,6 @@ class TransformFunctor { ...@@ -153,11 +132,6 @@ class TransformFunctor {
trans(x_, x_ + nx_, y_, z_, func_); trans(x_, x_ + nx_, y_, z_, func_);
} }
inline void RunRowWise(int n, int pre) const {
math::Transform trans;
trans(x_, x_ + nx_, RowwiseTransformIterator<T>(y_, n), z_, func_);
}
inline void RunMidWise(int n, int pre, int post) const { inline void RunMidWise(int n, int pre, int post) const {
math::Transform trans; math::Transform trans;
trans(x_, x_ + nx_, MidWiseTransformIterator<T>(y_, n, post), z_, func_); trans(x_, x_ + nx_, MidWiseTransformIterator<T>(y_, n, post), z_, func_);
...@@ -179,31 +153,25 @@ void ElementwiseComputeEx(const framework::Tensor *x, ...@@ -179,31 +153,25 @@ void ElementwiseComputeEx(const framework::Tensor *x,
auto x_dims = x->dims(); auto x_dims = x->dims();
auto y_dims = y->dims(); auto y_dims = y->dims();
// PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), PADDLE_MOBILE_ENFORCE(x_dims.size() >= y_dims.size(),
// "Rank of first input must >= rank of second "Rank of first input must >= rank of second input.");
// input.");
if (x_dims == y_dims) { if (x_dims == y_dims) {
functor.Run(); functor.Run();
return; return;
} }
/// axis = -1 represent the last dimension. /// axis = -1 represent the last dimensions.
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
// PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), PADDLE_MOBILE_ENFORCE(axis >= 0 && axis < x_dims.size(),
// "Axis should be in range [0, x_dims)"); "Axis should be in range [0, x_dims)");
trim_trailing_singular_dims(&y_dims); trim_trailing_singular_dims(&y_dims);
axis = (y_dims.size() == 0) ? x_dims.size() : axis; axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post; int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
if (post == 1) {
functor.RunRowWise(n, pre); functor.RunMidWise(n, pre, post);
return;
} else {
functor.RunMidWise(n, pre, post);
return;
}
} }
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册