未验证 提交 da5326e7 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #502 from itminner/face_op

add resize prelu scale op
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PRELU_OP
#include "operators/kernel/prelu_kernel.h"
#include <operators/math/transform.h>
namespace paddle_mobile {
namespace operators {
template <typename T>
struct PReluFunctor {
explicit PReluFunctor(float slope) { this->slope_ = slope; }
inline T operator()(T in) const { return in > 0 ? in : in * slope_; }
float slope_ = 0.0f;
};
/*
* @b 特化到具体平台的实现, param 从 op 层传入
* */
template <>
void PReluKernel<CPU, float>::Compute(const PReluParam &param) const {
const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out();
auto *out_ptr = out->mutable_data<float>();
if (param.Slopes().size() == 1) {
PReluFunctor<float> func_(param.Slopes()[0]);
math::Transform trans;
trans(input_x_ptr, input_x_ptr + input_x->numel(), out_ptr, func_);
} else if (param.Slopes().size() > 1) {
const int dim_size = input_x->dims().size();
switch (dim_size) {
case 0:
break;
case 1: {
const int input_width = input_x->dims()[0];
math::Transform trans;
#pragma omp parallel for
for (int w = 0; w < input_width; ++w) {
out_ptr[w] = input_x_ptr[w] * param.Slopes()[w];
}
} break;
case 2: {
const int input_height = input_x->dims()[0];
const int input_width = input_x->dims()[1];
math::Transform trans;
#pragma omp parallel for
for (int h = 0; h < input_height; ++h) {
PReluFunctor<float> func_(param.Slopes()[h]);
const float *ptr = input_x_ptr + h * input_width;
float *optr = out_ptr + +h * input_width;
trans(ptr, ptr + input_width, optr, func_);
}
} break;
case 3: {
const int chan_size = input_x->dims()[0];
const int input_height = input_x->dims()[1];
const int input_width = input_x->dims()[2];
math::Transform trans;
#pragma omp parallel for
for (int c = 0; c < chan_size; ++c) {
PReluFunctor<float> func_(param.Slopes()[c]);
int size = input_height * input_width;
const float *ptr = input_x_ptr + c * size;
float *optr = out_ptr + c * size;
trans(ptr, ptr + size, optr, func_);
}
} break;
case 4:
default: {
const int batch_size = input_x->dims()[0];
const int chan_size = input_x->dims()[1];
const int input_height = input_x->dims()[2];
const int input_width = input_x->dims()[3];
math::Transform trans;
#pragma omp parallel for
for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < chan_size; ++c) {
PReluFunctor<float> func_(param.Slopes()[c]);
int size = input_height * input_width;
const float *ptr = input_x_ptr + b * c * size;
float *optr = out_ptr + +b * c * size;
trans(ptr, ptr + size, optr, func_);
}
}
} // case 3,default
break;
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RESIZE_OP
#include "operators/kernel/resize_kernel.h"
#include <cmath>
namespace paddle_mobile {
namespace operators {
void BiLinearResizeTensor(const float* src, const int src_height,
const int src_width, float* dst, const int dst_height,
const int dst_width) {
const float scale_w = src_width / (float)dst_width;
const float scale_h = src_height / (float)dst_height;
float* dst_data = dst;
const float* src_data = src;
for (int dst_h = 0; dst_h < dst_height; ++dst_h) {
float fh = dst_h * scale_h;
int src_h = std::floor(fh);
fh -= src_h;
const float w_h0 = std::abs((float)1.0 - fh);
const float w_h1 = std::abs(fh);
const int dst_offset_1 = dst_h * dst_width;
const int src_offset_1 = src_h * src_width;
float* dst_data_ptr = dst_data + dst_offset_1;
for (int dst_w = 0; dst_w < dst_width; ++dst_w) {
float fw = dst_w * scale_w;
int src_w = std::floor(fw);
fw -= src_w;
const float w_w0 = std::abs((float)1.0 - fw);
const float w_w1 = std::abs(fw);
float dst_value = 0;
const int src_idx = src_offset_1 + src_w;
dst_value += (w_h0 * w_w0 * src_data[src_idx]);
int flag = 0;
if (src_w + 1 < src_width) {
dst_value += (w_h0 * w_w1 * src_data[src_idx + 1]);
++flag;
}
if (src_h + 1 < src_height) {
dst_value += (w_h1 * w_w0 * src_data[src_idx + src_width]);
++flag;
}
if (flag > 1) {
dst_value += (w_h1 * w_w1 * src_data[src_idx + src_width + 1]);
// ++flag;
}
*(dst_data_ptr++) = dst_value;
}
}
}
void ResizeTensor(const Tensor* src, const int src_n, const int src_c,
Tensor* dst, const int dst_n, const int dst_c) {
framework::DDim in_dims = src->dims();
const int src_chans = in_dims[1];
const int src_height = in_dims[2];
const int src_width = in_dims[3];
const int src_offset = (src_n * src_chans + src_c) * src_height * src_width;
framework::DDim out_dims = dst->dims();
const int dst_chans = out_dims[1];
const int dst_height = out_dims[2];
const int dst_width = out_dims[3];
const int dst_offset = (dst_n * dst_chans + dst_c) * dst_height * dst_width;
const auto* src_ptr = src->data<float>();
auto* dst_ptr = dst->data<float>();
const auto* src_data = &(src_ptr[src_offset]);
auto* dst_data = &(dst_ptr[dst_offset]);
BiLinearResizeTensor(src_data, src_height, src_width, dst_data, dst_height,
dst_width);
}
void ResizeTensor(const Tensor* src, Tensor* dst) {
framework::DDim in_dims = src->dims();
framework::DDim out_dims = dst->dims();
PADDLE_MOBILE_ENFORCE(in_dims[0] == out_dims[0],
"src tensor batch num not equal to dst tensor");
PADDLE_MOBILE_ENFORCE(in_dims[1] == out_dims[1],
"src tensor channel num not equal to dst tensor");
for (int n = 0, batch_num = in_dims[0]; n < batch_num; ++n) {
for (int c = 0, chan_num = in_dims[1]; c < chan_num; ++c) {
ResizeTensor(src, n, c, dst, n, c);
}
}
}
template <>
void ResizeKernel<CPU, float>::Compute(const ResizeParam& param) const {
const auto* input_x = param.InputX();
const auto& input_x_dims = input_x->dims();
auto* out = param.Out();
framework::DDim out_dims = CalOutputShape(param);
out->Resize(out_dims);
ResizeTensor(input_x, out);
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SCALE_OP
#include "operators/kernel/scale_kernel.h"
namespace paddle_mobile {
namespace operators {
/*
* @b 特化到具体平台的实现, param 从 op 层传入
* */
template <>
void ScaleKernel<CPU, float>::Compute(const ScaleParam &param) const {
const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out();
auto *out_ptr = out->mutable_data<float>();
const vector<float> scales = param.Scales();
bool has_bias = param.HasBias();
const int dim_size = input_x->dims().size();
switch (dim_size) {
case 1: {
const int input_width = input_x->dims()[0];
if (has_bias) {
const vector<float> biases = param.Biases();
#pragma omp parallel for
for (int w = 0; w < input_width; w++) {
out_ptr[w] = input_x_ptr[w] * scales[w] + biases[w];
}
} else {
#pragma omp parallel for
for (int w = 0; w < input_width; w++) {
out_ptr[w] = input_x_ptr[w] * scales[w];
}
}
} break;
case 2: {
const int input_height = input_x->dims()[0];
const int input_width = input_x->dims()[1];
if (has_bias) {
const vector<float> biases = param.Biases();
#pragma omp parallel for
for (int h = 0; h < input_height; ++h) {
const float *iptr = input_x_ptr + h * input_width;
float *optr = out_ptr + h * input_width;
for (int w = 0; w < input_width; ++w) {
optr[w] = iptr[w] * scales[w] + biases[w];
}
}
} else {
#pragma omp parallel for
for (int h = 0; h < input_height; ++h) {
const float *iptr = input_x_ptr + h * input_width;
float *optr = out_ptr + h * input_width;
for (int w = 0; w < input_width; ++w) {
optr[w] = iptr[w] * scales[w];
}
}
}
} break;
case 3: {
const int chan_size = input_x->dims()[0];
const int input_height = input_x->dims()[1];
const int input_width = input_x->dims()[2];
int size = input_width * input_height;
if (has_bias) {
const vector<float> biases = param.Biases();
#pragma omp parallel for
for (int c = 0; c < chan_size; ++c) {
const float *iptr = input_x_ptr + c * size;
float *optr = out_ptr + c * size;
for (int i = 0; i < size; ++i) {
optr[i] = iptr[i] * scales[c] + biases[c];
}
}
} else {
#pragma omp parallel for
for (int c = 0; c < chan_size; ++c) {
const float *iptr = input_x_ptr + c * size;
float *optr = out_ptr + c * size;
for (int i = 0; i < size; ++i) {
optr[i] = iptr[i] * scales[c];
}
}
}
} break;
case 4: {
const int batch_size = input_x->dims()[0];
const int chan_size = input_x->dims()[0];
const int input_height = input_x->dims()[1];
const int input_width = input_x->dims()[2];
int size = input_width * input_height;
if (has_bias) {
const vector<float> biases = param.Biases();
#pragma omp parallel for
for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < chan_size; ++c) {
const float *iptr = input_x_ptr + b * c * size;
float *optr = out_ptr + b * c * size;
for (int i = 0; i < size; ++i) {
optr[i] = iptr[i] * scales[c] + biases[c];
}
}
}
} else {
#pragma omp parallel for
for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < chan_size; ++c) {
const float *iptr = input_x_ptr + b * c * size;
float *optr = out_ptr + b * c * size;
for (int i = 0; i < size; ++i) {
optr[i] = iptr[i] * scales[c];
}
}
}
}
} break;
default:
break;
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SLICE_OP
#include "operators/kernel/slice_kernel.h"
namespace paddle_mobile {
namespace operators {}
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "framework/operator.h"
#include "operators/op_param.h"
#pragma once;
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class PReluKernel : public framework::OpKernelBase<DeviceType, PReluParam> {
public:
void Compute(const PReluParam& param) const;
};
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RESIZE_OP
#pragma once
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
inline framework::DDim CalOutputShape(const ResizeParam &param) {
const auto *input_x = param.InputX();
const auto &input_x_dims = input_x->dims();
auto *out = param.Out();
framework::DDim out_dims = out->dims();
const auto *input_shape = param.InputShape();
if (input_shape) {
auto *shape_data = input_shape->data<int>();
framework::Tensor cpu_shape_tensor;
auto shape =
std::vector<int>(shape_data, shape_data + input_shape->numel());
const int in_batch_size = input_x->dims()[0];
const int in_chan_size = input_x->dims()[1];
const int in_height = input_x->dims()[2];
const int in_width = input_x->dims()[3];
int out_height = 0;
int out_width = 0;
bool is_pyramid_test = param.IsPyramidTest();
if (is_pyramid_test == false) {
out_height = param.Height();
out_width = param.Width();
PADDLE_MOBILE_ENFORCE(out_height > 0, "output height is required");
PADDLE_MOBILE_ENFORCE(out_width > 0, "output width is required");
} else {
float out_height_scale = param.OutHeightScale();
float out_width_scale = param.OutWidthScale();
PADDLE_MOBILE_ENFORCE(out_height_scale > 0,
"output height scale is required");
PADDLE_MOBILE_ENFORCE(out_width_scale > 0,
"output width scale is required");
out_height = int(out_height_scale * in_height);
out_width = int(out_width_scale * in_width);
}
out_dims = framework::make_ddim(
{in_batch_size, in_chan_size, in_height, in_width});
}
return out_dims;
}
template <typename DeviceType, typename T>
class ResizeKernel : public framework::OpKernelBase<DeviceType, ResizeParam> {
public:
void Compute(const ResizeParam &param) const;
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "framework/operator.h"
#include "operators/op_param.h"
#pragma once;
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class ScaleKernel : public framework::OpKernelBase<DeviceType, ScaleParam> {
public:
void Compute(const ScaleParam& param) const;
};
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "framework/operator.h"
#include "operators/op_param.h"
#pragma once;
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class SliceKernel : public framework::OpKernelBase<DeviceType, SliceParam> {
public:
void Compute(const SliceParam& param) const {}
};
} // namespace operators
} // namespace paddle_mobile
......@@ -715,6 +715,123 @@ class ReshapeParam : public OpParam {
};
#endif
#ifdef SCALE_OP
class ScaleParam : public OpParam {
public:
ScaleParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<LoDTensor>(inputs, scope);
input_bias_ = InputBiasFrom<framework::LoDTensor>(inputs, scope);
out_ = OutFrom<LoDTensor>(outputs, scope);
inplace_ = GetAttr<bool>("inplace", attrs);
has_bias_ = GetAttr<bool>("has_bias", attrs);
scales_ = GetAttr<vector<float>>("scales", attrs);
biases_ = GetAttr<vector<float>>("biases", attrs);
}
const Tensor *InputX() const { return input_x_; }
const Tensor *InputBias() const { return input_bias_; }
Tensor *Out() const { return out_; }
const bool &Inplace() const { return inplace_; }
const bool &HasBias() const { return has_bias_; }
const vector<float> &Scales() const { return scales_; }
const vector<float> &Biases() const { return biases_; }
private:
Tensor *input_x_;
Tensor *input_bias_;
Tensor *out_;
bool inplace_;
bool has_bias_;
vector<float> scales_;
vector<float> biases_;
};
#endif
#ifdef SLICE_OP
class SliceParam : public OpParam {
public:
SliceParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<LoDTensor>(inputs, scope);
input_shape_ = InputShapeFrom<LoDTensor>(inputs, scope);
out_ = OutFrom<LoDTensor>(outputs, scope);
axis_ = GetAttr<int>("axis", attrs);
slice_points_ = GetAttr<vector<int>>("slice_points", attrs);
inplace_ = GetAttr<bool>("inplace", attrs);
}
const Tensor *InputX() const { return input_x_; }
const Tensor *InputShape() const { return input_shape_; }
Tensor *Out() const { return out_; }
const int &Axis() const { return axis_; }
const vector<int> &SlicePoints() const { return slice_points_; }
const bool &Inplace() const { return inplace_; }
private:
Tensor *input_x_;
Tensor *input_shape_;
Tensor *out_;
int axis_;
vector<int> slice_points_;
bool inplace_;
};
#endif
#ifdef RESIZE_OP
class ResizeParam : public OpParam {
public:
ResizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<LoDTensor>(inputs, scope);
input_shape_ = InputShapeFrom<LoDTensor>(inputs, scope);
out_ = OutFrom<LoDTensor>(outputs, scope);
is_pyramid_test_ = GetAttr<bool>("is_pyramid_test", attrs);
height_ = GetAttr<int>("height", attrs);
width_ = GetAttr<int>("width", attrs);
out_height_scale_ = GetAttr<float>("out_height_scale", attrs);
out_width_scale_ = GetAttr<float>("out_width_scale", attrs);
}
const Tensor *InputX() const { return input_x_; }
const Tensor *InputShape() const { return input_shape_; }
Tensor *Out() const { return out_; }
const bool &IsPyramidTest() const { return is_pyramid_test_; }
const int &Height() const { return height_; }
const int &Width() const { return width_; }
const float &OutHeightScale() const { return out_height_scale_; }
const float &OutWidthScale() const { return out_width_scale_; }
private:
Tensor *input_x_;
Tensor *input_shape_;
Tensor *out_;
bool is_pyramid_test_;
int height_;
int width_;
float out_height_scale_;
float out_width_scale_;
};
#endif
#ifdef RELU_OP
/*
* @b op 层实例化好这个 param 传递给 kernel 层使用
......@@ -737,6 +854,27 @@ class ReluParam : public OpParam {
};
#endif
#ifdef PRELU_OP
class PReluParam : public OpParam {
public:
PReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<LoDTensor>(inputs, scope);
out_ = OutFrom<LoDTensor>(outputs, scope);
slopes_ = GetAttr<vector<float>>("slopes", attrs);
}
const Tensor *InputX() const { return input_x_; }
Tensor *Out() const { return out_; }
const vector<float> &Slopes() const { return slopes_; }
private:
Tensor *input_x_;
Tensor *out_;
vector<float> slopes_;
};
#endif
#ifdef FUSION_FC_OP
class FusionFcParam : public OpParam {
public:
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PRELU_OP
#include "operators/prelu_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void PReluOp<Dtype, T>::InferShape() const {
auto input_dims = this->param_.InputX()->dims();
this->param_.Out()->Resize(input_dims);
}
template class PReluOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
/*
* @b 每一个 op 都需要注册一下的,
* USE_OP的参数 和 REGISTER_OPERATOR的第一个参数
* 都是需要和model中类型对应起来的
* */
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(prelu);
REGISTER_OPERATOR_CPU(prelu, ops::PReluOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
USE_OP_MALI_GPU(prelu);
REGISTER_OPERATOR_MALI_GPU(prelu, ops::PReluOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PRELU_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/prelu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class PReluOp
: public framework::OperatorWithKernel<
DeviceType, PReluParam, operators::PReluKernel<DeviceType, T>> {
public:
PReluOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, PReluParam,
operators::PReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, PReluParam,
operators::PReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RESIZE_OP
#include "operators/resize_op.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void ResizeOp<Dtype, T>::InferShape() const {
auto out_dims = CalOutputShape(this->param_);
this->param_.Out()->Resize(out_dims);
}
template class ResizeOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(resize);
REGISTER_OPERATOR_CPU(resize, ops::ResizeOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
USE_OP_MALI_GPU(resize);
REGISTER_OPERATOR_MALI_GPU(resize, ops::ResizeOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RESIZE_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/resize_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class ResizeOp
: public framework::OperatorWithKernel<
DeviceType, ResizeParam, operators::ResizeKernel<DeviceType, T>> {
public:
ResizeOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, ResizeParam,
operators::ResizeKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, ResizeParam,
operators::ResizeKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SCALE_OP
#include "operators/scale_op.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void ScaleOp<Dtype, T>::InferShape() const {
auto input_dims = this->param_.InputX()->dims();
this->param_.Out()->Resize(input_dims);
}
template class ScaleOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(scale);
REGISTER_OPERATOR_CPU(scale, ops::ScaleOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
USE_OP_MALI_GPU(scale);
REGISTER_OPERATOR_MALI_GPU(scale, ops::ScaleOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SCALE_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/scale_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class ScaleOp
: public framework::OperatorWithKernel<
DeviceType, ScaleParam, operators::ScaleKernel<DeviceType, T>> {
public:
ScaleOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, ScaleParam,
operators::ScaleKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, ScaleParam,
operators::ScaleKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SLICE_OP
#include "operators/slice_op.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void SliceOp<Dtype, T>::InferShape() const {
/// todo: add InputShape() detection.
}
template class SliceOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(slice);
REGISTER_OPERATOR_CPU(slice, ops::SliceOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
USE_OP_MALI_GPU(slice);
REGISTER_OPERATOR_MALI_GPU(slice, ops::SliceOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SLICE_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/slice_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class SliceOp
: public framework::OperatorWithKernel<
DeviceType, SliceParam, operators::SliceKernel<DeviceType, T>> {
public:
SliceOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, SliceParam,
operators::SliceKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, SliceParam,
operators::SliceKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../executor_for_test.h"
#include "../test_include.h"
#include "operators/prelu_op.h"
int main() {
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(g_resnet);
PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
"program file read fail");
Executor4Test<paddle_mobile::CPU,
paddle_mobile::operators::PReluOp<paddle_mobile::CPU, float>>
executor(program, "prelu");
// 1. input_tensors;
vector<Tensor> input_tensors;
Tensor input1;
auto input1_data = CreateInput<float>(&input1, {1, 2, 3, 4}, -1, 1);
input_tensors.push_back(input1);
// 2. input_names
vector<string> input_names({
"batch_norm_0.tmp_2",
});
// 3. output_names
vector<string> output_names({"batch_norm_0.tmp_3"});
// 4. out_dims;
vector<DDim> out_ddims;
auto out_ddim = paddle_mobile::framework::make_ddim({1, 2, 3, 4});
out_ddims.push_back(out_ddim);
auto output = executor.Predict<LoDTensor>(input_tensors, input_names,
output_names, out_ddims);
auto output0_data = output[0]->data<float>();
for (int j = 0; j < output[0]->numel(); ++j) {
DLOG << " value of output: " << output0_data[j];
}
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/resize_op.h"
int main() {
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_mobilenet_ssd));
if (program.originProgram == nullptr) {
DLOG << "program read file";
}
Executor4Test<paddle_mobile::CPU,
paddle_mobile::operators::ResizeOp<paddle_mobile::CPU, float>>
executor(program, "resize");
paddle_mobile::framework::Tensor input;
SetupTensor<float>(&input, {2, 3, 3, 2}, static_cast<float>(0),
static_cast<float>(1));
auto input_ptr = input.data<float>();
auto out_ddim = paddle_mobile::framework::make_ddim({2, 9, 2});
auto output =
executor.Predict(input, "transpose_0.tmp_0", "reshape_0.tmp_0", out_ddim);
auto *output_ptr = output->data<float>();
DLOG << "input : ";
for (int j = 0; j < input.numel(); ++j) {
DLOG << " index " << j << " : " << input_ptr[j];
}
DLOG << "output : ";
for (int j = 0; j < output->numel(); ++j) {
DLOG << " index " << j << " : " << output_ptr[j];
}
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/scale_op.h"
int main() {}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/slice_op.h"
int main() {}
......@@ -160,4 +160,4 @@ else
build_error
fi
fi
fi
fi
\ No newline at end of file
......@@ -64,6 +64,10 @@ else ()
set(TRANSPOSE_OP ON)
set(FUSION_CONVADD_RELU_OP ON)
set(FUSION_CONVADDBNRELU_OP ON)
set(PRELU_OP ON)
set(RESIZE_OP ON)
set(SCALE_OP ON)
set(SLICE_OP ON)
set(DROPOUT_OP ON)
set(IM2SEQUENCE_OP ON)
# option(BATCHNORM_OP "" ON)
......@@ -151,6 +155,17 @@ endif()
if (FUSION_CONVADDBNRELU_OP)
add_definitions(-DFUSION_CONVADDBNRELU_OP)
endif()
if (PRELU_OP)
add_definitions(-DPRELU_OP)
endif()
if (RESIZE_OP)
add_definitions(-DRESIZE_OP)
endif()
if (SCALE_OP)
add_definitions(-DSCALE_OP)
endif()
if (SLICE_OP)
add_definitions(-DSLICE_OP)
if (DROPOUT_OP)
add_definitions(-DDROPOUT_OP)
endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册