Unverified commit 33a995ce, authored by xiebaiyuan, committed by GitHub

develop nearest_interp op; fix bug in elementwise add; close #1647 (#1648)

* develop nearest_interp op; fix bug in elementwise add; close #1647

* remove log
Parent 3cf63ef9
......@@ -71,7 +71,9 @@ const char *G_OP_TYPE_GRU = "gru";
const char *G_OP_TYPE_GRU_UNIT = "gru_unit";
const char *G_OP_TYPE_CRF = "crf_decoding";
const char *G_OP_TYPE_BILINEAR_INTERP = "bilinear_interp";
const char *G_OP_TYPE_NEAREST_INTERP = "nearest_interp";
const char *G_OP_TYPE_FLATTEN = "flatten";
const char *G_OP_TYPE_FLATTEN2 = "flatten2";
const char *G_OP_TYPE_SHAPE = "shape";
const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_TOP_K = "top_k";
......@@ -177,7 +179,9 @@ std::unordered_map<
{"Gate", "ResetHiddenPrev", "Hidden"}}},
{G_OP_TYPE_CRF, {{"Emission", "Transition", "Label"}, {"ViterbiPath"}}},
{G_OP_TYPE_BILINEAR_INTERP, {{"OutSize", "X"}, {"Out"}}},
{G_OP_TYPE_NEAREST_INTERP, {{"OutSize", "X"}, {"Out"}}},
{G_OP_TYPE_FLATTEN, {{"X"}, {"Out"}}},
{G_OP_TYPE_FLATTEN2, {{"X"}, {"Out"}}},
{G_OP_TYPE_SHAPE, {{"Input"}, {"Out"}}},
{G_OP_TYPE_CONV_TRANSPOSE, {{"Input"}, {"Output"}}},
{G_OP_TYPE_SUM, {{"X"}, {"Out"}}},
......
......@@ -134,6 +134,12 @@ extern const char *G_OP_TYPE_FUSION_CONV_BN_RELU;
extern const char *G_OP_TYPE_GRU;
extern const char *G_OP_TYPE_GRU_UNIT;
extern const char *G_OP_TYPE_CRF;
extern const char *G_OP_TYPE_BILINEAR_INTERP;
extern const char *G_OP_TYPE_NEAREST_INTERP;
extern const char *G_OP_TYPE_FLATTEN;
extern const char *G_OP_TYPE_FLATTEN2;
extern const char *G_OP_TYPE_SHAPE;
extern const char *G_OP_TYPE_LRN;
extern const char *G_OP_TYPE_MUL;
extern const char *G_OP_TYPE_MULTICLASS_NMS;
......
......@@ -73,6 +73,12 @@ LOAD_OP2(batch_norm, CPU, GPU_CL);
#ifdef BILINEAR_INTERP_OP
LOAD_OP1(bilinear_interp, CPU);
#endif
#ifdef NEAREST_INTERP_OP
LOAD_OP1(nearest_interp, CPU);
#endif
#ifdef LEAKY_RELU_OP
LOAD_OP1(leaky_relu, CPU);
#endif
#ifdef BOXCODER_OP
LOAD_OP2(box_coder, CPU, GPU_CL);
#endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef NEAREST_INTERP_OP
#include "operators/kernel/nearest_interp_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool NearestInterpolationKernel<CPU, float>::Init(
NearestInterpolationParam<CPU>* param) {
return true;
}
template <>
void NearestInterpolationKernel<CPU, float>::Compute(
const NearestInterpolationParam<CPU>& param) {
auto out_dims = param.Out()->dims();
auto* input = param.InputX()->data<float>();
auto out_size_t = param.InputOutPutSize();
int out_h = param.OutH();
int out_w = param.OutW();
if (out_size_t != nullptr) {
auto out_size_data = out_size_t->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
auto* output = param.Out()->mutable_data<float>(
{out_dims[0], out_dims[1], out_h, out_w});
auto batch_size = param.InputX()->dims()[0];
auto channels = param.InputX()->dims()[1];
auto in_h = param.InputX()->dims()[2];
auto in_w = param.InputX()->dims()[3];
auto in_hw = in_h * in_w;
auto out_hw = out_h * out_w;
auto in_chw = channels * in_hw;
auto out_chw = channels * out_hw;
float ratio_h =
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) {
memcpy(output, input, param.InputX()->numel() * sizeof(float));
} else {
for (int k = 0; k < batch_size; ++k) { // loop for batches
for (int i = 0; i < out_h; ++i) {  // loop for output rows
int h = ratio_h * i + 0.5f;
for (int j = 0; j < out_w; ++j) {
int w = ratio_w * j + 0.5f;
// locate the nearest source pixel (h and w rounded with +0.5f above)
const float* in_pos = &input[k * in_chw + h * in_w + w];
float* out_pos = &output[k * out_chw + i * out_w + j];
for (int c = 0; c < channels; ++c) { // loop for channels
// nearest interpolation
out_pos[0] = in_pos[0];
in_pos += in_hw;
out_pos += out_hw;
}
}
}
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
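For orientation, the kernel's index math reads in isolation as: output pixel (i, j) copies input pixel (round(ratio_h * i), round(ratio_w * j)), with align-corners ratios (in - 1) / (out - 1). Below is a minimal standalone sketch of that mapping; the function and parameter names are illustrative, and only the arithmetic mirrors the kernel above.
// Minimal standalone sketch of the nearest-neighbor resize over an NCHW
// buffer. Names are illustrative; only the index arithmetic mirrors the
// kernel above (align-corners ratios, +0.5f rounding).
void NearestInterpNCHW(const float* in, float* out, int batch, int channels,
                       int in_h, int in_w, int out_h, int out_w) {
  const float ratio_h =
      (out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
  const float ratio_w =
      (out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
  const int in_hw = in_h * in_w;
  const int out_hw = out_h * out_w;
  for (int k = 0; k < batch; ++k) {
    for (int c = 0; c < channels; ++c) {
      const float* in_c = in + (k * channels + c) * in_hw;
      float* out_c = out + (k * channels + c) * out_hw;
      for (int i = 0; i < out_h; ++i) {
        const int h = static_cast<int>(ratio_h * i + 0.5f);  // nearest row
        for (int j = 0; j < out_w; ++j) {
          const int w = static_cast<int>(ratio_w * j + 0.5f);  // nearest col
          out_c[i * out_w + j] = in_c[h * in_w + w];
        }
      }
    }
  }
}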
......@@ -31,7 +31,6 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
const framework::Tensor *input_y = param.InputY();
framework::Tensor *output = param.Out();
int axis = param.Axis();
math::AddElememtWise<IDENTITY>(input_x, input_y, axis, output);
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef NEAREST_INTERP_OP
#pragma once
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class NearestInterpolationKernel
: public framework::OpKernelBase<DeviceType,
NearestInterpolationParam<DeviceType>> {
public:
void Compute(const NearestInterpolationParam<DeviceType>& param);
bool Init(NearestInterpolationParam<DeviceType>* param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -214,50 +214,87 @@ void AddElememtWise(const framework::Tensor *input,
float *output_data = output->mutable_data<float>();
if (x_dims == y_dims) {
int remain_start = 0;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
remain_start = input->numel() & 0xfffffffc;
#pragma omp parallel for
for (int i = 0; i < input->numel() - 15; i += 16) {
float32x4_t r0 = vld1q_f32(input_data);
float32x4_t r1 = vld1q_f32(input_data + 4);
float32x4_t r2 = vld1q_f32(input_data + 8);
float32x4_t r3 = vld1q_f32(input_data + 12);
float32x4_t b0 = vld1q_f32(bias_data);
float32x4_t b1 = vld1q_f32(bias_data + 4);
float32x4_t b2 = vld1q_f32(bias_data + 8);
float32x4_t b3 = vld1q_f32(bias_data + 12);
r0 = vaddq_f32(r0, b0);
r1 = vaddq_f32(r1, b1);
r2 = vaddq_f32(r2, b2);
r3 = vaddq_f32(r3, b3);
r0 = math::vActiveq_f32<Act>(r0);
r1 = math::vActiveq_f32<Act>(r1);
r2 = math::vActiveq_f32<Act>(r2);
r3 = math::vActiveq_f32<Act>(r3);
vst1q_f32(output_data, r0);
vst1q_f32(output_data + 4, r1);
vst1q_f32(output_data + 8, r2);
vst1q_f32(output_data + 12, r3);
input_data += 16;
bias_data += 16;
output_data += 16;
}
for (int i = input->numel() & 0xfffffff0; i < input->numel() - 3; i += 4) {
float32x4_t r0 = vld1q_f32(input_data);
float32x4_t b0 = vld1q_f32(bias_data);
r0 = vaddq_f32(r0, b0);
r0 = math::vActiveq_f32<Act>(r0);
vst1q_f32(output_data, r0);
input_data += 4;
bias_data += 4;
output_data += 4;
size_t channels = 1;
size_t elementwise_num = 1;
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
#pragma omp parallel for
for (int j = 0; j < channels; ++j) {
size_t offset = (0 * channels + j) * elementwise_num;
const float *input = input_data + offset;
const float bias = bias_data[j];
float *output = output_data + offset;
#if 0
int loop = elementwise_num >> 0x4;
int remain = elementwise_num & 0xF;
float32x4_t rb = vdupq_n_f32(bias);
for (int k = 0; k < loop; ++k) {
float32x4_t r0 = vld1q_f32(input);
float32x4_t r1 = vld1q_f32(input + 4);
float32x4_t r2 = vld1q_f32(input + 8);
float32x4_t r3 = vld1q_f32(input + 12);
r0 = vaddq_f32(r0, rb);
r1 = vaddq_f32(r1, rb);
r2 = vaddq_f32(r2, rb);
r3 = vaddq_f32(r3, rb);
r0 = math::vActiveq_f32<Act>(r0);
r1 = math::vActiveq_f32<Act>(r1);
r2 = math::vActiveq_f32<Act>(r2);
r3 = math::vActiveq_f32<Act>(r3);
vst1q_f32(output, r0);
vst1q_f32(output + 4, r1);
vst1q_f32(output + 8, r2);
vst1q_f32(output + 12, r3);
input += 16;
output += 16;
}
if (remain >= 8) {
float32x4_t r0 = vld1q_f32(input);
float32x4_t r1 = vld1q_f32(input + 4);
r0 = vaddq_f32(r0, rb);
r1 = vaddq_f32(r1, rb);
r0 = math::vActiveq_f32<Act>(r0);
r1 = math::vActiveq_f32<Act>(r1);
vst1q_f32(output, r0);
vst1q_f32(output + 4, r1);
input += 8;
output += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t r0 = vld1q_f32(input);
r0 = vaddq_f32(r0, rb);
r0 = math::vActiveq_f32<Act>(r0);
vst1q_f32(output, r0);
input += 4;
output += 4;
remain -= 4;
}
if (remain > 0) {
float32x4_t r0 = vld1q_f32(input);
r0 = vaddq_f32(r0, rb);
r0 = math::vActiveq_f32<Act>(r0);
switch (remain) {
case 1:
vst1q_lane_f32(output, r0, 0);
break;
case 2:
vst1_f32(output, vget_low_f32(r0));
break;
case 3:
vst1_f32(output, vget_low_f32(r0));
vst1q_lane_f32(output + 2, r0, 2);  // third element goes after the pair
break;
}
}
#else
for (int k = 0; k < elementwise_num; ++k) {
output[k] = math::Active<Act>(input[k] + bias);
}
#endif  // NEON path (currently disabled via #if 0)
for (int i = remain_start; i < input->numel(); ++i) {
output_data[i] = math::Active<Act>(input_data[i] + bias_data[i]);
}
} else {
// axis = -1 represent the last dimensions.
int dim = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
......@@ -274,7 +311,7 @@ void AddElememtWise(const framework::Tensor *input,
elementwise_num *= x_dims[i];
}
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
size_t offset = (i * channels + j) * elementwise_num;
......
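The broadcast branch of AddElememtWise (the batch / channels / elementwise_num split above) vectorizes a simple pattern: view x as [batch, channels, elementwise_num] and add y[j] to every element of the (i, j) slice. A plain scalar sketch of that pattern, with illustrative names and the IDENTITY activation assumed:
// Scalar reference for the broadcast pattern the NEON code accelerates.
// x is viewed as [batch, channels, elementwise_num], y as [channels];
// names are illustrative and Act == IDENTITY is assumed.
void BroadcastAddRef(const float* x, const float* y, float* out,
                     int batch, int channels, int elementwise_num) {
  for (int i = 0; i < batch; ++i) {
    for (int j = 0; j < channels; ++j) {
      const int offset = (i * channels + j) * elementwise_num;
      for (int k = 0; k < elementwise_num; ++k) {
        out[offset + k] = x[offset + k] + y[j];
      }
    }
  }
}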
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef NEAREST_INTERP_OP
#include "operators/nearest_interp_op.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void NearestInterpolationOp<DeviceType, T>::InferShape() const {
PADDLE_MOBILE_ENFORCE(this->param_.InputX() != nullptr,
"Input(X) of NearestInterpolationOp should not be null.");
PADDLE_MOBILE_ENFORCE(this->param_.Out() != nullptr,
"Output(Out) of NearestInterpolationOp should not be null.");
auto dim_x = this->param_.InputX()->dims(); // NCHW format
int out_h = this->param_.OutH();
int out_w = this->param_.OutW();
PADDLE_MOBILE_ENFORCE(dim_x.size() == 4, "X's dimension must be 4");
if (this->param_.InputOutPutSize() != nullptr) {
auto out_size_dim = this->param_.InputOutPutSize()->dims();
PADDLE_MOBILE_ENFORCE(out_size_dim.size() == 1,
"OutSize's dimension size must be 1");
PADDLE_MOBILE_ENFORCE(out_size_dim[0] == 2, "OutSize's dim[0] must be 2");
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
this->param_.Out()->Resize(framework::make_ddim(dim_out));
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(nearest_interp, ops::NearestInterpolationOp);
#endif
#endif
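A worked example of the shape rule above, with illustrative values:
// Worked example (illustrative values):
//   X: [N=1, C=3, H=32, W=32], attrs out_h = 64, out_w = 64
//   InferShape resizes Out to [1, 3, 64, 64].
// If an OutSize tensor holding {128, 128} is fed at runtime, the CPU
// kernel overrides the attrs and reallocates Out as [1, 3, 128, 128].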
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef NEAREST_INTERP_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/nearest_interp_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class NearestInterpolationOp
: public framework::OperatorWithKernel<
DeviceType, NearestInterpolationParam<DeviceType>,
operators::NearestInterpolationKernel<DeviceType, T>> {
public:
NearestInterpolationOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
framework::Scope *scope)
: framework::OperatorWithKernel<
DeviceType, NearestInterpolationParam<DeviceType>,
operators::NearestInterpolationKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -2849,6 +2849,38 @@ class BilinearInterpParam : public OpParam {
};
#endif
#ifdef NEAREST_INTERP_OP
template <typename Dtype>
class NearestInterpolationParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
NearestInterpolationParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_x_ = InputXFrom<GType>(inputs, *scope);
input_outsize_ = InputOutSizeFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
out_h_ = GetAttr<int>("out_h", attrs);
out_w_ = GetAttr<int>("out_w", attrs);
}
const GType *InputX() const { return input_x_; }
const GType *InputOutPutSize() const { return input_outsize_; }
GType *Out() const { return out_; }
int OutH() const { return out_h_; }
int OutW() const { return out_w_; }
private:
GType *input_x_;
GType *input_outsize_;
GType *out_;
int out_h_;
int out_w_;
};
#endif
#ifdef SHAPE_OP
template <typename Dtype>
class ShapeParam : public OpParam {
......
......@@ -363,6 +363,7 @@ if(NOT FOUND_MATCH)
set(PAD2D_OP ON)
set(ONE_HOT_OP ON)
set(ASSIGN_VALUE_OP ON)
set(NEAREST_INTERP_OP ON)
set(LEAKY_RELU_OP ON)
endif()
......@@ -695,4 +696,7 @@ if (ASSIGN_VALUE_OP)
endif()
if (LEAKY_RELU_OP)
add_definitions(-DLEAKY_RELU_OP)
endif()
if (NEAREST_INTERP_OP)
add_definitions(-DNEAREST_INTERP_OP)
endif()