提交 29cd089a 编写于 作者: G guosheng

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into add-multiBatch-chunkEval

...@@ -28,6 +28,7 @@ limitations under the License. */ ...@@ -28,6 +28,7 @@ limitations under the License. */
#include "hl_top_k.h" #include "hl_top_k.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "NEONFunctions.h"
#include "paddle/function/GemmFunctor.h" #include "paddle/function/GemmFunctor.h"
#include "paddle/utils/ThreadLocal.h" #include "paddle/utils/ThreadLocal.h"
...@@ -4165,16 +4166,36 @@ void CpuMatrix::print(std::ostream& os) const { ...@@ -4165,16 +4166,36 @@ void CpuMatrix::print(std::ostream& os) const {
void CpuMatrix::paramReluForward(Matrix& data, Matrix& W) { void CpuMatrix::paramReluForward(Matrix& data, Matrix& W) {
real* input = data.getData(); real* input = data.getData();
real* w = W.getData(); real* w = W.getData();
real* output = data_;
size_t numElements = data.getWidth(); size_t numElements = data.getWidth();
size_t numSamples = data.getHeight(); size_t numSamples = data.getHeight();
size_t paraSize = W.getHeight() * W.getWidth(); size_t paraSize = W.getHeight() * W.getWidth();
CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init
size_t partial_sum = numElements / paraSize; size_t partial_sum = numElements / paraSize;
if (paraSize == numElements) {
for (size_t n = 0; n < numSamples * numElements; ++n) {
output[n] = input[n] > 0 ? input[n] : input[n] * w[n % numElements];
}
return;
}
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
for (size_t n = 0; n < numSamples; ++n) {
for (size_t i = 0; i < paraSize; i++) {
neon::prelu(
input + i * partial_sum, w[i], output + i * partial_sum, partial_sum);
}
input = input + numElements;
output = output + numElements;
}
#else
for (size_t n = 0, k = 0; n < numSamples; ++n) { for (size_t n = 0, k = 0; n < numSamples; ++n) {
for (size_t i = 0; i < numElements; ++i, ++k) { for (size_t i = 0; i < numElements; ++i, ++k) {
data_[k] = input[k] > 0 ? input[k] : input[k] * w[i / partial_sum]; output[k] = input[k] > 0 ? input[k] : input[k] * w[i / partial_sum];
} }
} }
#endif
} }
void CpuMatrix::paramReluBackwardW(Matrix& oGrad, Matrix& data) { void CpuMatrix::paramReluBackwardW(Matrix& oGrad, Matrix& data) {
......
...@@ -49,6 +49,46 @@ void relu(const float* a, float* b, int len) { ...@@ -49,6 +49,46 @@ void relu(const float* a, float* b, int len) {
} }
} }
// b[i] = a[i] > 0.0f ? a[i] : a[i] * w
void prelu(const float* a, float w, float* b, int len) {
int offset = len % 16;
float32x4_t ma0, ma1, ma2, ma3;
float32x4_t zero = vdupq_n_f32(0.f);
float32x4_t vw = vdupq_n_f32(w);
for (int k = 0; k < len / 16; k++, a += 16, b += 16) {
ma0 = vld1q_f32(a);
ma1 = vld1q_f32(a + 4);
ma2 = vld1q_f32(a + 8);
ma3 = vld1q_f32(a + 12);
uint32x4_t flag0 = vcgtq_f32(ma0, zero);
uint32x4_t flag1 = vcgtq_f32(ma1, zero);
uint32x4_t flag2 = vcgtq_f32(ma2, zero);
uint32x4_t flag3 = vcgtq_f32(ma3, zero);
float32x4_t mul0 = vmulq_f32(ma0, vw);
float32x4_t mul1 = vmulq_f32(ma1, vw);
float32x4_t mul2 = vmulq_f32(ma2, vw);
float32x4_t mul3 = vmulq_f32(ma3, vw);
ma0 = vbslq_f32(flag0, ma0, mul0);
ma1 = vbslq_f32(flag1, ma1, mul1);
ma2 = vbslq_f32(flag2, ma2, mul2);
ma3 = vbslq_f32(flag3, ma3, mul3);
vst1q_f32(b, ma0);
vst1q_f32(b + 4, ma1);
vst1q_f32(b + 8, ma2);
vst1q_f32(b + 12, ma3);
}
for (int i = 0; i < offset; i++) {
b[i] = a[i] > 0.0f ? a[i] : a[i] * w;
}
}
} // namespace neon } // namespace neon
} // namespace paddle } // namespace paddle
......
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace neon { namespace neon {
void relu(const float* a, float* b, int len); void relu(const float* a, float* b, int len);
void prelu(const float* a, float w, float* b, int len);
} // namespace neon } // namespace neon
} // namespace paddle } // namespace paddle
...@@ -243,7 +243,6 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence, ...@@ -243,7 +243,6 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
int block_x, int block_y, int block_x, int block_y,
const size_t *batch_indices, T *dfilter) { const size_t *batch_indices, T *dfilter) {
int blx = blockDim.x; int blx = blockDim.x;
int bly = blockDim.y;
int thx = threadIdx.x; int thx = threadIdx.x;
int thy = threadIdx.y; int thy = threadIdx.y;
int gx = blockIdx.x * blx; int gx = blockIdx.x * blx;
......
...@@ -1738,8 +1738,10 @@ def conv2d_transpose(input, ...@@ -1738,8 +1738,10 @@ def conv2d_transpose(input,
h_in = input.shape[2] h_in = input.shape[2]
w_in = input.shape[3] w_in = input.shape[3]
filter_size_h = output_size[0] - (h_in - 1) * stride[0] + 2 * padding[0] filter_size_h = output_size[0] - \
filter_size_w = output_size[1] - (w_in - 1) * stride[1] + 2 * padding[1] (h_in - 1) * stride[0] + 2 * padding[0]
filter_size_w = output_size[1] - \
(w_in - 1) * stride[1] + 2 * padding[1]
filter_size = [filter_size_h, filter_size_w] filter_size = [filter_size_h, filter_size_w]
elif isinstance(filter_size, int): elif isinstance(filter_size, int):
filter_size = [filter_size, filter_size] filter_size = [filter_size, filter_size]
......
...@@ -9,6 +9,7 @@ def simple_img_conv_pool(input, ...@@ -9,6 +9,7 @@ def simple_img_conv_pool(input,
pool_size, pool_size,
pool_stride, pool_stride,
act, act,
param_attr=None,
pool_type='max', pool_type='max',
main_program=None, main_program=None,
startup_program=None): startup_program=None):
...@@ -16,6 +17,7 @@ def simple_img_conv_pool(input, ...@@ -16,6 +17,7 @@ def simple_img_conv_pool(input,
input=input, input=input,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
param_attr=param_attr,
act=act, act=act,
main_program=main_program, main_program=main_program,
startup_program=startup_program) startup_program=startup_program)
...@@ -36,6 +38,7 @@ def img_conv_group(input, ...@@ -36,6 +38,7 @@ def img_conv_group(input,
conv_padding=1, conv_padding=1,
conv_filter_size=3, conv_filter_size=3,
conv_act=None, conv_act=None,
param_attr=None,
conv_with_batchnorm=False, conv_with_batchnorm=False,
conv_batchnorm_drop_rate=None, conv_batchnorm_drop_rate=None,
pool_stride=1, pool_stride=1,
...@@ -57,6 +60,7 @@ def img_conv_group(input, ...@@ -57,6 +60,7 @@ def img_conv_group(input,
conv_padding = __extend_list__(conv_padding) conv_padding = __extend_list__(conv_padding)
conv_filter_size = __extend_list__(conv_filter_size) conv_filter_size = __extend_list__(conv_filter_size)
param_attr = __extend_list__(param_attr)
conv_with_batchnorm = __extend_list__(conv_with_batchnorm) conv_with_batchnorm = __extend_list__(conv_with_batchnorm)
conv_batchnorm_drop_rate = __extend_list__(conv_batchnorm_drop_rate) conv_batchnorm_drop_rate = __extend_list__(conv_batchnorm_drop_rate)
...@@ -70,6 +74,7 @@ def img_conv_group(input, ...@@ -70,6 +74,7 @@ def img_conv_group(input,
num_filters=conv_num_filter[i], num_filters=conv_num_filter[i],
filter_size=conv_filter_size[i], filter_size=conv_filter_size[i],
padding=conv_padding[i], padding=conv_padding[i],
param_attr=param_attr[i],
act=local_conv_act, act=local_conv_act,
main_program=main_program, main_program=main_program,
startup_program=startup_program) startup_program=startup_program)
...@@ -101,6 +106,7 @@ def img_conv_group(input, ...@@ -101,6 +106,7 @@ def img_conv_group(input,
def sequence_conv_pool(input, def sequence_conv_pool(input,
num_filters, num_filters,
filter_size, filter_size,
param_attr=None,
act="sigmoid", act="sigmoid",
pool_type="max", pool_type="max",
main_program=None, main_program=None,
...@@ -109,6 +115,7 @@ def sequence_conv_pool(input, ...@@ -109,6 +115,7 @@ def sequence_conv_pool(input,
input=input, input=input,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
param_attr=param_attr,
act=act, act=act,
main_program=main_program, main_program=main_program,
startup_program=startup_program) startup_program=startup_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册