提交 0aa19cc9 编写于 作者: Z zhangyang

Merge remote-tracking branch 'upstream/develop' into develop

......@@ -73,6 +73,8 @@ const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
extern const char *G_OP_TYPE_TANH = "tanh";
extern const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu";
extern const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add";
extern const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU = "fusion_deconv_add_relu";
std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
......@@ -133,5 +135,7 @@ std::unordered_map<
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}}};
{G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}}};
} // namespace paddle_mobile
......@@ -142,6 +142,9 @@ extern const char *G_OP_TYPE_DEQUANTIZE;
extern const char *G_OP_TYPE_TANH;
extern const char *G_OP_TYPE_FUSION_DECONV_RELU;
extern const char *G_OP_TYPE_FUSION_DECONV_ADD;
extern const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU;
extern std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key;
......
......@@ -281,6 +281,9 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
clock_gettime(CLOCK_MONOTONIC, &ts);
profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
if (loddable_) {
ops[i]->InferShape();
}
// to Run
ops[i]->Run();
#ifdef PADDLE_MOBILE_PROFILE
......
......@@ -43,15 +43,21 @@ void Loader<Dtype, P>::InitMemoryFromProgram(
tensor->Resize(make_ddim(dim));
} else {
auto dim = var_desc->Tensor_desc().Dims();
PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0");
// PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0");
// dim[0] = 1;
for (auto &d : dim) {
if (d < 0) {
d *= -1;
if (dim.size() == 0) {
auto tensor = var->GetMutable<LoDTensor>();
framework::DDim dDim = {0};
tensor->Resize(dDim);
} else {
for (auto &d : dim) {
if (d < 0) {
d *= -1;
}
}
auto tensor = var->GetMutable<LoDTensor>();
tensor->Resize(make_ddim(dim));
}
auto tensor = var->GetMutable<LoDTensor>();
tensor->Resize(make_ddim(dim));
}
} else {
// TODO(codeWorm): some.
......
......@@ -405,9 +405,9 @@ Java_com_baidu_paddle_PML_predictLod(JNIEnv *env, jclass thiz, jlongArray buf) {
ANDROIDLOGE("predict nlp size %d", count);
result = env->NewLongArray(count);
env->SetLongArrayRegion(result, 0, count, vec_result->data<int64_t>());
env->ReleaseLongArrayElements(buf, ddim_ptr, 0);
return result;
}
......
......@@ -122,9 +122,12 @@ void PaddleMobile<Dtype, P>::Clear() {
executor_ = nullptr;
loader_ = nullptr;
}
template <typename Dtype, Precision P>
double PaddleMobile<Dtype, P>::GetPredictTime() {
double PaddleMobile<Dtype, P>::GetPredictTime() {}
#ifdef PADDLE_MOBILE_CPU
template <>
double PaddleMobile<CPU, Precision::FP32>::GetPredictTime() {
int m = 32;
int n = 224 * 224;
int k = 27;
......@@ -147,8 +150,8 @@ double PaddleMobile<Dtype, P>::GetPredictTime() {
}
paddle_mobile::operators::math::Gemm gemm;
auto time1 = paddle_mobile::time();
// gemm.Sgemm(m, n, k, static_cast<float>(1), a, lda, b, ldb,
// static_cast<float>(0), c, ldc, false, nullptr);
gemm.Sgemm(m, n, k, static_cast<float>(1), a, lda, b, ldb,
static_cast<float>(0), c, ldc, false, nullptr);
auto time2 = paddle_mobile::time();
double cost = paddle_mobile::time_diff(time1, time2);
paddle_mobile::memory::Free(a);
......@@ -156,6 +159,7 @@ double PaddleMobile<Dtype, P>::GetPredictTime() {
paddle_mobile::memory::Free(c);
return cost;
}
#endif
template <typename Dtype, Precision P>
PaddleMobile<Dtype, P>::~PaddleMobile() {
......
......@@ -55,6 +55,9 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(box_coder, ops::BoxCoderOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(box_coder, ops::BoxCoderOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
......
......@@ -66,6 +66,9 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(concat, ops::ConcatOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(concat, ops::ConcatOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(concat, ops::ConcatOp);
#endif
......
......@@ -21,7 +21,13 @@ template <typename DeviceType, typename T>
void FeedOp<DeviceType, T>::InferShape() const {
auto out_dims = this->param_.Out()->dims();
out_dims[0] = this->param_.BatchSize();
this->param_.Out()->Resize(out_dims);
auto input_dims = this->param_.InputX()->dims();
DLOG << input_dims.size();
if (input_dims.size() == 4) {
this->param_.Out()->Resize(input_dims);
} else {
this->param_.Out()->Resize(out_dims);
}
}
} // namespace operators
......
......@@ -61,5 +61,7 @@ REGISTER_OPERATOR_MALI_GPU(fusion_conv_add, ops::FusionConvAddOp);
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(fusion_conv_add, ops::FusionConvAddOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fusion_conv_add, ops::FusionConvAddOp);
#endif
#endif
......@@ -29,8 +29,9 @@ namespace operators {
class FusionConvAddReluOpMatcher : public framework::FusionOpMatcher {
public:
FusionConvAddReluOpMatcher() {
// node_ = framework::Node(G_OP_TYPE_FUSION_CONV_ADD);
// node_ > std::make_shared<framework::Node>(G_OP_TYPE_RELU);
node_ = framework::Node(G_OP_TYPE_CONV);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
......
......@@ -54,6 +54,9 @@ REGISTER_FUSION_MATCHER(fusion_conv_bn_relu, ops::FusionConvBNReluMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_conv_bn_relu, ops::FusionConvBNReluOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(fusion_conv_bn_relu, ops::FusionConvBNReluOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fusion_conv_bn_relu, ops::FusionConvBNReluOp);
#endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADD_OP
#include "operators/fusion_deconv_add_op.h"
namespace paddle_mobile {
namespace operators {}
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_deconv_add, ops::FusionDeconvAddMatcher);
#ifdef PADDLE_MOBILE_CPU
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fusion_deconv_add, ops::FusionDeconvAddOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADD_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/deconv_add_kernel.h"
namespace paddle_mobile {
namespace operators {
using std::string;
using std::vector;
class FusionDeconvAddMatcher : public framework::FusionOpMatcher {
public:
FusionDeconvAddMatcher() {
node_ = framework::Node(G_OP_TYPE_CONV_TRANSPOSE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DECONV_ADD; }
};
template <typename DeviceType, typename T>
class FusionDeconvAddOp : public framework::OperatorWithKernel<
DeviceType, FusionDeconvAddParam<DeviceType>,
operators::DeconvAddKernel<DeviceType, T>> {
public:
FusionDeconvAddOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDeconvAddParam<DeviceType>,
operators::DeconvAddKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
void InferShape() const {
auto input = this->param_.Input();
auto in_dims = input->dims();
auto filter = this->param_.Filter();
auto filter_dims = filter->dims();
std::vector<int> strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
std::vector<int> dilations = this->param_.Dilations();
int groups = this->param_.Groups();
PADDLE_MOBILE_ENFORCE(
in_dims.size() == 4 || in_dims.size() == 5,
"ConvTransposeOp intput should be 4-D or 5-D tensor.");
PADDLE_MOBILE_ENFORCE(
in_dims.size() == filter_dims.size(),
"ConvTransposeOp input dimension and filter dimension "
"should be the same.");
PADDLE_MOBILE_ENFORCE(
in_dims.size() - strides.size() == 2U,
"ConvTransposeOp input dimension and strides dimension should "
"be consistent.");
PADDLE_MOBILE_ENFORCE(paddings.size() == strides.size(),
"ConvTransposeOp paddings dimension and strides "
"dimension should be the same.");
PADDLE_MOBILE_ENFORCE(paddings.size() == dilations.size(),
"ConvTransposeOp paddings dimension and dilations "
"dimension should be the same.");
PADDLE_MOBILE_ENFORCE(
in_dims[1] == filter_dims[0],
"In ConvTransposeOp, The number of input channels should "
"be equal to the number of filter's channels.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[1] * groups});
for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
output_shape.push_back((in_dims[i + 2] - 1) * strides[i] -
2 * paddings[i] + filter_extent);
}
this->param_.Output()->Resize(framework::make_ddim(output_shape));
}
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif // FUSION_DECONV_ADD_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADDRELU_OP
#include "operators/fusion_deconv_add_relu_op.h"
namespace paddle_mobile {
namespace operators {}
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_deconv_add_relu,
ops::FusionDeconvAddReluMatcher);
#ifdef PADDLE_MOBILE_CPU
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fusion_deconv_add_relu, ops::FusionDeconvAddReluOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADDRELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/deconv_add_relu_kernel.h"
namespace paddle_mobile {
namespace operators {
using std::string;
using std::vector;
class FusionDeconvAddReluMatcher : public framework::FusionOpMatcher {
public:
FusionDeconvAddReluMatcher() {
node_ = framework::Node(G_OP_TYPE_CONV_TRANSPOSE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DECONV_ADD_RELU; }
};
template <typename DeviceType, typename T>
class FusionDeconvAddReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionDeconvAddReluParam<DeviceType>,
operators::DeconvAddReluKernel<DeviceType, T>> {
public:
FusionDeconvAddReluOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDeconvAddReluParam<DeviceType>,
operators::DeconvAddReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const {
auto input = this->param_.Input();
auto in_dims = input->dims();
auto filter = this->param_.Filter();
auto filter_dims = filter->dims();
std::vector<int> strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
std::vector<int> dilations = this->param_.Dilations();
int groups = this->param_.Groups();
PADDLE_MOBILE_ENFORCE(
in_dims.size() == 4 || in_dims.size() == 5,
"ConvTransposeOp intput should be 4-D or 5-D tensor.");
PADDLE_MOBILE_ENFORCE(
in_dims.size() == filter_dims.size(),
"ConvTransposeOp input dimension and filter dimension "
"should be the same.");
PADDLE_MOBILE_ENFORCE(
in_dims.size() - strides.size() == 2U,
"ConvTransposeOp input dimension and strides dimension should "
"be consistent.");
PADDLE_MOBILE_ENFORCE(paddings.size() == strides.size(),
"ConvTransposeOp paddings dimension and strides "
"dimension should be the same.");
PADDLE_MOBILE_ENFORCE(paddings.size() == dilations.size(),
"ConvTransposeOp paddings dimension and dilations "
"dimension should be the same.");
PADDLE_MOBILE_ENFORCE(
in_dims[1] == filter_dims[0],
"In ConvTransposeOp, The number of input channels should "
"be equal to the number of filter's channels.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[1] * groups});
for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
output_shape.push_back((in_dims[i + 2] - 1) * strides[i] -
2 * paddings[i] + filter_extent);
}
this->param_.Output()->Resize(framework::make_ddim(output_shape));
}
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif // FUSION_DECONV_ADD_RELU_OP
......@@ -54,6 +54,9 @@ REGISTER_FUSION_MATCHER(fusion_dwconv_bn_relu, ops::FusionDWConvBNReluMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_dwconv_bn_relu, ops::FusionDWConvBNReluOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(fusion_dwconv_bn_relu, ops::FusionDWConvBNReluOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
......
......@@ -115,6 +115,7 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
template <typename P>
void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
param.Output()->mutable_data<float>();
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef BOXCODER_OP
#include "operators/kernel/box_coder_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool BoxCoderKernel<GPU_CL, float>::Init(BoxCoderParam<GPU_CL> *param) {
return true;
}
template <>
void BoxCoderKernel<GPU_CL, float>::Compute(
const BoxCoderParam<GPU_CL> &param) {}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define BATCH_NORM
#define RELU
#include "conv_kernel.inc.cl"
......@@ -20,7 +20,8 @@ __kernel void fetch(__private const int in_height,
__global float* out,
__private const int size_ch,
__private const int size_block,
__private const int size_batch) {
__private const int size_batch,
__private const int C) {
const int in_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
......@@ -35,9 +36,17 @@ __kernel void fetch(__private const int in_height,
const int index = in_n * size_batch + in_c * size_block + in_h * in_width + in_w;
out[index] = convert_float(in.x);
out[index + size_ch] = convert_float(in.y);
if(C - 4 * in_c>=2){
out[index + size_ch] = convert_float(in.y);
}
if(C - 4 * in_c>=3){
out[index + size_ch * 2] = convert_float(in.z);
out[index + size_ch * 3] = convert_float(in.w);
}
if(C - 4 * in_c>=4){
out[index + size_ch * 3] = convert_float(in.w);
}
}
__kernel void fetch_2d(__private const int in_height,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void prior_box(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
__global float *box_width,
__global float *box_height,
__write_only image2d_t output_image,
__private const float step_width,
__private const float step_height,
__private const float offset,
__private const int img_width,
__private const int img_height,
__private const int num_priors,
__private const int C){
const int out_c = get_global_id(0);
const int out_nh = get_global_id(1);
const int out_n = out_nh/num_priors;
const int out_h = out_nh%num_priors;
if (out_c >= global_size_dim0 ||out_nh >= global_size_dim2) {
return;
}
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 output_pos;
output_pos.x = out_c * 4;
output_pos.y = out_nh;
float center_x0 = (offset + out_c * 4) * step_width;
float center_x1 = (offset + out_c * 4 + 1) * step_width;
float center_x2 = (offset + out_c * 4 + 2) * step_width;
float center_x3 = (offset + out_c * 4 + 3) * step_width;
float center_y = (out_n + offset) * step_height;
half4 output[4];
output[0].x = convert_half((center_x0 - box_width[out_h]) / img_width);
output[1].x = convert_half((center_y - box_height[out_h]) / img_height);
output[2].x = convert_half((center_x0 + box_width[out_h]) / img_width);
output[3].x = convert_half((center_y + box_height[out_h]) / img_height);
if(C - 4 * out_c>=2){
output[0].y = convert_half((center_x1 - box_width[out_h]) / img_width);
output[1].y = convert_half((center_y - box_height[out_h]) / img_height);
output[2].y = convert_half((center_x1 + box_width[out_h]) / img_width);
output[3].y = convert_half((center_y + box_height[out_h]) / img_height);
}else{
output[0].y = 0.0f;
output[1].y = 0.0f;
output[2].y = 0.0f;
output[3].y = 0.0f;
}
if(C - 4 * out_c>=3){
output[0].z = convert_half((center_x2 - box_width[out_h]) / img_width);
output[1].z = convert_half((center_y - box_height[out_h]) / img_height);
output[2].z = convert_half((center_x2 + box_width[out_h]) / img_width);
output[3].z = convert_half((center_y + box_height[out_h]) / img_height);
}else{
output[0].z = 0.0f;
output[1].z = 0.0f;
output[2].z = 0.0f;
output[3].z = 0.0f;
}
if(C - 4 * out_c>=4){
output[0].w = convert_half((center_x3 - box_width[out_h]) / img_width);
output[1].w = convert_half((center_y - box_height[out_h]) / img_height);
output[2].w = convert_half((center_x3 + box_width[out_h]) / img_width);
output[3].w = convert_half((center_y + box_height[out_h]) / img_height);
}else{
output[0].z = 0.0f;
output[1].z = 0.0f;
output[2].z = 0.0f;
output[3].z = 0.0f;
}
output[0] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[0]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
output[1] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[1]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
output[2] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[2]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
output[3] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[3]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
write_imageh(output_image, (int2)(output_pos.x + 1, output_pos.y), output[0]);
write_imageh(output_image, (int2)(output_pos.x + 2, output_pos.y), output[1]);
write_imageh(output_image, (int2)(output_pos.x + 3, output_pos.y), output[2]);
write_imageh(output_image, (int2)(output_pos.x + 4, output_pos.y), output[3]);
}
\ No newline at end of file
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONCAT_OP
#include "operators/kernel/concat_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConcatKernel<GPU_CL, float>::Init(ConcatParam<GPU_CL> *param) {
return true;
}
template <>
void ConcatKernel<GPU_CL, float>::Compute(const ConcatParam<GPU_CL> &param) {}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -68,10 +68,10 @@ void ConvAddKernel<GPU_CL, float>::Compute(
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
DLOG << "---yangfei30---";
DLOG << *param.Filter();
DLOG << param.Paddings();
auto biase = param.Bias()->GetCLImage();
param.Output()->InitEmptyImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue(),
param.Output()->dims());
auto output = param.Output()->GetCLImage();
int stride = param.Strides()[0];
int offset = param.Offset();
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVBNRELU_OP
#include "operators/kernel/conv_bn_relu_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvBNReluKernel<GPU_CL, float>::Init(
FusionConvBNReluParam<GPU_CL> *param) {
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
const framework::CLImage *mean = param->InputMean();
const framework::CLImage *variance = param->InputVariance();
const framework::CLImage *scale = param->InputScale();
const framework::CLImage *bias = param->InputBias();
const float epsilon = param->Epsilon();
const int C = mean->numel();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
float *new_scale_ptr = new float[C];
float *new_bias_ptr = new float[C];
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
framework::CLImage *new_scale = new framework::CLImage();
// for (int j = 0; j < C; ++j) {
// DLOG << " new scale - " << j << new_scale_ptr[j];
// }
//
// for (int j = 0; j < C; ++j) {
// DLOG << " new bias - " << j << new_bias_ptr[j];
// }
new_scale->SetTensorData(new_scale_ptr, variance->dims());
new_scale->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
// DLOG << " climage - y bias: " << *(param->Bias());
//
// DLOG << " climage - new scale: " << *new_scale;
framework::CLImage *new_bias = new framework::CLImage();
new_bias->SetTensorData(new_bias_ptr, variance->dims());
new_bias->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
// DLOG << " climage - new bias: " << *new_bias;
//
// DLOG << " climage - filter: " << *(param->Filter());
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
delete[](new_scale_ptr);
delete[](new_bias_ptr);
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
if (param->Filter()->dims()[2] == 1 && param->Filter()->dims()[3] == 1) {
param->Filter()->InitNImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_1x1", "conv_bn_relu_kernel.cl");
DLOG << " conv bn relu conv 1x1";
} else if (param->Filter()->dims()[1] == 1 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == 3) {
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("depth_conv_3x3", "conv_bn_relu_kernel.cl");
DLOG << " conv bn relu depth_conv_3x3";
} else if (param->Filter()->dims()[2] == 3 &&
param->Filter()->dims()[3] == 3) {
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_3x3", "conv_bn_relu_kernel.cl");
DLOG << " conv bn relu conv_3x3";
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
}
return true;
}
template <>
void ConvBNReluKernel<GPU_CL, float>::Compute(
const FusionConvBNReluParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
auto new_scale = param.NewScale()->GetCLImage();
auto new_bias = param.NewBias()->GetCLImage();
auto output = param.Output()->GetCLImage();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
param.Input()->Converter())
->GetCBlock();
int dilation = param.Dilations()[0];
int input_width = param.Input()->dims()[3];
int input_height = param.Input()->dims()[2];
int output_width = param.Output()->dims()[3];
int output_height = param.Output()->dims()[2];
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &new_scale);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_bias);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &output);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(int), &stride);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &input_c);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &dilation);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 13, sizeof(int), &input_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 14, sizeof(int), &output_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 15, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
template class ConvBNReluKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DWCONVBNRELU_OP
#include "operators/kernel/dwconv_bn_relu_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool DWConvBNReluKernel<GPU_CL, float>::Init(
FusionDWConvBNReluParam<GPU_CL> *param) {
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
const framework::CLImage *mean = param->InputMean();
const framework::CLImage *variance = param->InputVariance();
const framework::CLImage *scale = param->InputScale();
const framework::CLImage *bias = param->InputBias();
const float epsilon = param->Epsilon();
const int C = mean->numel();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
float *new_scale_ptr = new float[C];
float *new_bias_ptr = new float[C];
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
framework::CLImage *new_scale = new framework::CLImage();
new_scale->SetTensorData(new_scale_ptr, variance->dims());
new_scale->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
framework::CLImage *new_bias = new framework::CLImage();
new_bias->SetTensorData(new_bias_ptr, variance->dims());
new_bias->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
delete[](new_scale_ptr);
delete[](new_bias_ptr);
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("depth_conv_3x3", "conv_bn_relu_kernel.cl");
DLOG << " conv bn relu depth_conv_3x3";
return true;
}
template <>
void DWConvBNReluKernel<GPU_CL, float>::Compute(
const FusionDWConvBNReluParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
auto new_scale = param.NewScale()->GetCLImage();
auto new_bias = param.NewBias()->GetCLImage();
auto output = param.Output()->GetCLImage();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
param.Input()->Converter())
->GetCBlock();
int dilation = param.Dilations()[0];
int input_width = param.Input()->dims()[3];
int input_height = param.Input()->dims()[2];
int output_width = param.Output()->dims()[3];
int output_height = param.Output()->dims()[2];
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &new_scale);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_bias);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &output);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(int), &stride);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &input_c);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &dilation);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 13, sizeof(int), &input_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 14, sizeof(int), &output_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 15, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
template class DWConvBNReluKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -28,6 +28,8 @@ template <>
void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
cl_int status;
param.Out()->InitEmptyImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue(), param.Out()->dims());
auto output = param.Out();
const Tensor *input = param.InputX();
// DLOG << *input;
......
......@@ -27,8 +27,6 @@ bool FetchKernel<GPU_CL, float>::Init(FetchParam<GPU_CL> *param) {
} else {
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
}
auto *out = param->Out();
out->mutable_data<float>();
return true;
}
......@@ -39,7 +37,7 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
auto input = param.InputX()->GetCLImage();
auto *out = param.Out();
out->mutable_data<float>();
const auto &dim = param.InputX()->dims();
size_t new_dims[] = {1, 1, 1, 1};
......@@ -70,9 +68,11 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
int size_ch = in_height * in_width;
int size_block = size_ch * 4;
int size_batch = size_ch * C;
int out_c = new_dims[1];
clSetKernelArg(kernel, 4, sizeof(int), &size_ch);
clSetKernelArg(kernel, 5, sizeof(int), &size_block);
clSetKernelArg(kernel, 6, sizeof(int), &size_batch);
clSetKernelArg(kernel, 7, sizeof(int), &out_c);
}
// cl_event wait_event = param.InpdutX()->GetClEvent();
......@@ -93,6 +93,8 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
// << "ms" << std::endl;
memcpy(out->data<float>(), out_cl_tensor.Data<float>(), out->memory_size());
DLOG << *param.InputX();
DLOG << *out;
}
template class FetchKernel<GPU_CL, float>;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef MULTICLASSNMS_OP
#include "operators/kernel/multiclass_nms_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool MultiClassNMSKernel<GPU_CL, float>::Init(
MultiClassNMSParam<GPU_CL> *param) {
return true;
}
template <>
void MultiClassNMSKernel<GPU_CL, float>::Compute(
const MultiClassNMSParam<GPU_CL> &param) {}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PRIORBOX_OP
#include "operators/kernel/prior_box_kernel.h"
#include "framework/cl/cl_tensor.h"
namespace paddle_mobile {
namespace operators {
template <>
bool PriorBoxKernel<GPU_CL, float>::Init(PriorBoxParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("prior_box", "prior_box_kernel.cl");
return true;
}
template <>
void PriorBoxKernel<GPU_CL, float>::Compute(
const PriorBoxParam<GPU_CL> &param) {
const auto *input_ = param.Input();
const auto &input_dims = input_->dims();
const auto &input_image_dims = param.InputImage()->dims();
const auto &min_sizes = param.MinSizes();
const auto &max_sizes = param.MaxSizes();
const auto &variances = param.Variances();
const auto &input_aspect_ratio = param.AspectRatios();
const bool &flip = param.Flip();
const bool &clip = param.Clip();
const float &step_w = param.StepW();
const float &step_h = param.StepH();
const float &offset = param.Offset();
const int C = param.OutputBoxes()->dims()[1];
auto output_boxes = param.OutputBoxes()->GetCLImage();
auto output_variances = param.OutputVariances()->GetCLImage();
std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
auto img_width = input_image_dims[3];
auto img_height = input_image_dims[2];
auto feature_width = input_dims[3];
auto feature_height = input_dims[2];
float step_width, step_height;
/// 300 / 19
if (step_w == 0 || step_h == 0) {
step_width = static_cast<float>(img_width) / feature_width;
step_height = static_cast<float>(img_height) / feature_height;
} else {
step_width = step_w;
step_height = step_h;
}
int num_priors = aspect_ratios.size() * min_sizes.size();
if (!max_sizes.empty()) {
num_priors += max_sizes.size();
}
float *box_width = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * num_priors));
float *box_height = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * num_priors));
int idx = 0;
for (size_t s = 0; s < min_sizes.size(); ++s) {
auto min_size = min_sizes[s];
if (param.MinMaxAspectRatiosOrder()) {
box_width[idx] = box_height[idx] = min_size / 2.;
idx++;
if (max_sizes.size() > 0) {
auto max_size = max_sizes[s];
box_width[idx] = box_height[idx] = sqrt(min_size * max_size) / 2.;
idx++;
}
for (float ar : aspect_ratios) {
if (fabs(ar - 1.) < 1e-6) {
continue;
}
box_width[idx] = min_size * sqrt(ar) / 2.;
box_height[idx] = min_size / sqrt(ar) / 2.;
idx++;
}
} else {
for (float ar : aspect_ratios) {
box_width[idx] = min_size * sqrt(ar) / 2.;
box_height[idx] = min_size / sqrt(ar) / 2.;
idx++;
}
if (!max_sizes.empty()) {
auto max_size = max_sizes[s];
box_width[idx] = box_height[idx] = sqrt(min_size * max_size) / 2.;
idx++;
}
}
}
cl_int status;
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size =
this->cl_helper_.DefaultWorkSize(*param.OutputBoxes());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
std::vector<int64_t> box_shape({1, 1, 1, num_priors});
framework::DDim ddim = framework::make_ddim(box_shape);
framework::CLTensor box_width_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
box_width_cl_tensor.Resize(ddim);
cl_mem box_width_Buffer =
box_width_cl_tensor.mutable_with_data<float>(box_width);
framework::CLTensor box_height_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
box_height_cl_tensor.Resize(ddim);
cl_mem box_height_Buffer =
box_height_cl_tensor.mutable_with_data<float>(box_height);
DLOG << "c_block:" << c_block;
DLOG << "w:" << w;
DLOG << "nh:" << nh;
DLOG << "step_width:" << step_width;
DLOG << "step_height:" << step_height;
DLOG << "offset:" << offset;
DLOG << "img_width:" << img_width;
DLOG << "img_height:" << img_height;
DLOG << "num_priors:" << num_priors;
DLOG << "C:" << C;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &box_width_Buffer);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &box_height_Buffer);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &output_boxes);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(float), &step_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(float), &step_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(float), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &img_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &img_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &num_priors);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &C);
CL_CHECK_ERRORS(status);
size_t global_work_size[2] = {c_block, nh};
status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2,
NULL, global_work_size, NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
paddle_mobile::memory::Free(box_width);
paddle_mobile::memory::Free(box_height);
}
template class PriorBoxKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TRANSPOSE_OP
#include "operators/kernel/transpose_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool TransposeKernel<GPU_CL, float>::Init(TransposeParam<GPU_CL> *param) {
return true;
}
template <>
void TransposeKernel<GPU_CL, float>::Compute(
const TransposeParam<GPU_CL> &param) {}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADD_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class DeconvAddKernel
: public OpKernelBase<DeviceType, FusionDeconvAddParam<DeviceType>> {
public:
void Compute(const FusionDeconvAddParam<DeviceType> &param);
bool Init(FusionDeconvAddParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADDRELU_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class DeconvAddReluKernel
: public OpKernelBase<DeviceType, FusionDeconvAddReluParam<DeviceType>> {
public:
void Compute(const FusionDeconvAddReluParam<DeviceType> &param);
bool Init(FusionDeconvAddReluParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADD_OP
#include "operators/kernel/conv_add_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddKernel<FPGA, float>::Init(FusionConvAddParam<FPGA> *param) {
bool relu_enabled = false;
auto input = const_cast<Tensor *>(param->Input());
const Tensor *bias = param->Bias();
auto bias_ptr = bias->data<float>();
auto filter = const_cast<Tensor *>(param->Filter());
auto out = param->Output();
PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0],
"Output channel should be equal to bias number");
int channel = out->dims()[1];
auto bs_ptr =
(float *)fpga::fpga_malloc(2 * channel * sizeof(float)); // NOLINT
for (int i = 0; i < channel; i++) {
bs_ptr[i + channel] = 1;
bs_ptr[i] = bias_ptr[i];
}
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0],
param->Strides()[1], param->Paddings()[0],
param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(conv_arg);
return true;
}
template <>
void ConvAddKernel<FPGA, float>::Compute(
const FusionConvAddParam<FPGA> &param) {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADD_OP
#include "operators/kernel/deconv_add_kernel.h"
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <>
bool DeconvAddKernel<FPGA, float>::Init(FusionDeconvAddParam<FPGA> *param) {
return true;
}
template <>
void DeconvAddKernel<FPGA, float>::Compute(
const FusionDeconvAddParam<FPGA> &param) {}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADDRELU_OP
#include "operators/kernel/deconv_add_relu_kernel.h"
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <>
bool DeconvAddReluKernel<FPGA, float>::Init(
FusionDeconvAddReluParam<FPGA> *param) {
return true;
}
template <>
void DeconvAddReluKernel<FPGA, float>::Compute(
const FusionDeconvAddReluParam<FPGA> &param) {}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SPLIT_OP
#include "operators/kernel/split_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool SplitKernel<FPGA, float>::Init(SplitParam<FPGA>* param) {
return true;
}
template <>
void SplitKernel<FPGA, float>::Compute(const SplitParam<FPGA>& param) {}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TRANSPOSE2_OP
#include "operators/kernel/transpose2_kernel.h"
#include "operators/kernel/central-arm-func/transpose2_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool Transpose2Kernel<FPGA, float>::Init(Transpose2Param<FPGA> *param) {
return true;
}
template <>
void Transpose2Kernel<FPGA, float>::Compute(
const Transpose2Param<FPGA> &param) {
// Transpose2Compute<float>(param);
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -43,5 +43,8 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(multiclass_nms, ops::MultiClassNMSOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(multiclass_nms, ops::MultiClassNMSOp);
#endif
#endif
......@@ -2221,7 +2221,10 @@ class ConvTransposeParam : public OpParam {
const Scope &scope) {
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
output_ = OutputFrom<GType>(outputs, scope);
// output_ = OutputFrom<GType>(outputs, scope);
if (outputs.count("Output")) {
output_ = OpParam::OutputFrom<GType>(outputs, scope);
}
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
......@@ -2262,6 +2265,38 @@ class ConvTransposeParam : public OpParam {
#endif
};
#endif
#ifdef FUSION_DECONVADD_OP
template <typename Dtype>
class FusionDeconvAddParam : public ConvTransposeParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDeconvAddParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: ConvTransposeParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope);
}
RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
RType *Output() const { return output_; }
protected:
RType *bias_;
int axis_;
RType *output_;
};
#endif
#ifdef FUSION_DECONVADDRELU_OP
template <typename Dtype>
using FusionDeconvAddReluParam = FusionDeconvAddParam<Dtype>;
#endif
#ifdef FUSION_DECONVRELU_OP
template <typename Dtype>
......
......@@ -54,5 +54,7 @@ REGISTER_OPERATOR_CPU(prior_box, ops::PriorBoxOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(prior_box, ops::PriorBoxOp);
#endif
#endif
......@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SPLIT_OP
#include "operators/split_op.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
......@@ -83,5 +83,8 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(split, ops::SplitOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(split, ops::SplitOp);
#endif
#endif // SPLIT_OP
......@@ -29,7 +29,7 @@ void TanhOp<DeviceType, T>::InferShape() const {
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(Tanh, ops::TanhOp);
REGISTER_OPERATOR_FPGA(tanh, ops::TanhOp);
#endif
#endif
......@@ -60,5 +60,8 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(transpose2, ops::Transpose2Op);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(transpose2, ops::Transpose2Op);
#endif
#endif // TRANSPOSE_OP
......@@ -55,5 +55,8 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(transpose, ops::TransposeOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(transpose, ops::TransposeOp);
#endif
#endif // TRANSPOSE_OP
......@@ -66,6 +66,10 @@ list(FIND NET "FPGA_NET_V1" CON)
if (CON GREATER -1)
ADD_EXECUTABLE(test-resnet50 fpga/test_resnet50.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-resnet50 paddle-mobile)
ADD_EXECUTABLE(test-densebox net/test_densebox_combine.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-densebox paddle-mobile)
set(FOUND_MATCH ON)
endif ()
......@@ -76,6 +80,10 @@ if (CON GREATER -1)
ADD_EXECUTABLE(test-pe fpga/test_pe.cpp)
target_link_libraries(test-pe paddle-mobile)
ADD_EXECUTABLE(test-densebox net/test_densebox_combine.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-densebox paddle-mobile)
set(FOUND_MATCH ON)
endif ()
......
......@@ -127,6 +127,8 @@ endif()
list(FIND NET "FPGA_NET_V2" CON)
if (CON GREATER -1)
message("FPGA_NET_V2 enabled")
set(FEED_OP ON)
set(FUSION_CONVADDRELU_OP ON)
set(FUSION_ELEMENTWISEADDRELU_OP ON)
set(FUSION_FC_OP ON)
set(POOL_OP ON)
......@@ -135,9 +137,16 @@ if (CON GREATER -1)
set(FUSION_CONVBN_OP ON)
set(CONV_TRANSPOSE_OP ON)
set(FUSION_DECONVRELU_OP ON)
set(SLICE_OP ON)
#set(SLICE_OP ON)
set(TANH_OP ON)
set(ELEMENTWISEADD_OP ON)
set(TRANSPOSE2_OP ON)
set(FUSION_CONVADD_OP ON)
set(SPLIT_OP ON)
set(FUSION_DECONVADD_OP ON)
set(FUSION_DECONVADDRELU_OP ON)
set(FOUND_MATCH ON)
endif()
......@@ -452,4 +461,10 @@ if (TANH_OP)
endif()
if (FUSION_DECONVRELU_OP)
add_definitions(-DFUSION_DECONVRELU_OP)
endif()
if (FUSION_DECONVADD_OP)
add_definitions(-DFUSION_DECONVADD_OP)
endif()
if (FUSION_DECONVADDRELU_OP)
add_definitions(-DFUSION_DECONVADDRELU_OP)
endif()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册