Commit 7d02bc6d authored by H hjchen2

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle-mobile into dev-latest

@@ -73,6 +73,8 @@ const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
extern const char *G_OP_TYPE_TANH = "tanh";
extern const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu";
extern const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add";
extern const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU = "fusion_deconv_add_relu";
std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
@@ -133,5 +135,7 @@ std::unordered_map<
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD_RELU, {{"Input"}, {"Out"}}}};
} // namespace paddle_mobile
@@ -142,6 +142,9 @@ extern const char *G_OP_TYPE_DEQUANTIZE;
extern const char *G_OP_TYPE_TANH;
extern const char *G_OP_TYPE_FUSION_DECONV_RELU;
extern const char *G_OP_TYPE_FUSION_DECONV_ADD;
extern const char *G_OP_TYPE_FUSION_DECONV_ADD_RELU;
extern std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
op_input_output_key;
...
@@ -280,6 +280,9 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
clock_gettime(CLOCK_MONOTONIC, &ts);
profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
if (loddable_) {
ops[i]->InferShape();
}
// to Run
ops[i]->Run();
#ifdef PADDLE_MOBILE_PROFILE
...
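Note: the guarded InferShape() call above re-derives each op's output dims on every Predict when the program is loddable, i.e. its inputs carry variable-length LoD data. Shapes inferred once at load time would go stale for such models between requests, while fixed-shape models keep the old, cheaper path.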
@@ -43,15 +43,21 @@ void Loader<Dtype, P>::InitMemoryFromProgram(
tensor->Resize(make_ddim(dim));
} else {
auto dim = var_desc->Tensor_desc().Dims();
// PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0");
// dim[0] = 1;
if (dim.size() == 0) {
auto tensor = var->GetMutable<LoDTensor>();
framework::DDim dDim = {0};
tensor->Resize(dDim);
} else {
for (auto &d : dim) {
if (d < 0) {
d *= -1;
}
}
auto tensor = var->GetMutable<LoDTensor>();
tensor->Resize(make_ddim(dim));
}
}
} else {
// TODO(codeWorm): some.
...
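Note: the reworked branch tolerates variables whose serialized dim list is empty. Instead of tripping the old dim.size() > 0 enforce, such tensors are resized to a one-entry DDim of {0}, a valid zero-sized shape, and the existing negative-dim fix-up is kept for everything else.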
@@ -405,9 +405,9 @@ Java_com_baidu_paddle_PML_predictLod(JNIEnv *env, jclass thiz, jlongArray buf) {
ANDROIDLOGE("predict nlp size %d", count);
result = env->NewLongArray(count);
env->SetLongArrayRegion(result, 0, count, vec_result->data<int64_t>());
env->ReleaseLongArrayElements(buf, ddim_ptr, 0);
return result;
}
...
@@ -122,9 +122,12 @@ void PaddleMobile<Dtype, P>::Clear() {
executor_ = nullptr;
loader_ = nullptr;
}
template <typename Dtype, Precision P>
double PaddleMobile<Dtype, P>::GetPredictTime() {}
#ifdef PADDLE_MOBILE_CPU
template <>
double PaddleMobile<CPU, Precision::FP32>::GetPredictTime() {
int m = 32;
int n = 224 * 224;
int k = 27;
@@ -147,8 +150,8 @@ double PaddleMobile<Dtype, P>::GetPredictTime() {
}
paddle_mobile::operators::math::Gemm gemm;
auto time1 = paddle_mobile::time();
gemm.Sgemm(m, n, k, static_cast<float>(1), a, lda, b, ldb,
static_cast<float>(0), c, ldc, false, nullptr);
auto time2 = paddle_mobile::time();
double cost = paddle_mobile::time_diff(time1, time2);
paddle_mobile::memory::Free(a);
@@ -156,6 +159,7 @@ double PaddleMobile<Dtype, P>::GetPredictTime() {
paddle_mobile::memory::Free(c);
return cost;
}
#endif
template <typename Dtype, Precision P>
PaddleMobile<Dtype, P>::~PaddleMobile() {
...
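The Sgemm benchmark now lives in a CPU/FP32 specialization guarded by PADDLE_MOBILE_CPU, so non-CPU builds no longer reference the Gemm code (the generic template is left as an empty stub). A rough usage sketch, with an illustrative threshold and fallback that are assumptions, not part of this commit:

// Minimal sketch: time one 32 x (224*224) x 27 Sgemm on this device.
paddle_mobile::PaddleMobile<paddle_mobile::CPU, paddle_mobile::Precision::FP32> pm;
double cost = pm.GetPredictTime();
if (cost > 50.0) {
  // device looks slow; an app might pick a lighter model here (assumption)
}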
@@ -21,7 +21,13 @@ template <typename DeviceType, typename T>
void FeedOp<DeviceType, T>::InferShape() const {
auto out_dims = this->param_.Out()->dims();
out_dims[0] = this->param_.BatchSize();
auto input_dims = this->param_.InputX()->dims();
DLOG << input_dims.size();
if (input_dims.size() == 4) {
this->param_.Out()->Resize(input_dims);
} else {
this->param_.Out()->Resize(out_dims);
}
}
} // namespace operators
...
@@ -61,5 +61,7 @@ REGISTER_OPERATOR_MALI_GPU(fusion_conv_add, ops::FusionConvAddOp);
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(fusion_conv_add, ops::FusionConvAddOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fusion_conv_add, ops::FusionConvAddOp);
#endif
#endif
@@ -29,8 +29,9 @@ namespace operators {
class FusionConvAddReluOpMatcher : public framework::FusionOpMatcher {
public:
FusionConvAddReluOpMatcher() {
node_ = framework::Node(G_OP_TYPE_CONV);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADD_OP
#include "operators/fusion_deconv_add_op.h"
namespace paddle_mobile {
namespace operators {}
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_deconv_add, ops::FusionDeconvAddMatcher);
#ifdef PADDLE_MOBILE_CPU
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fusion_deconv_add, ops::FusionDeconvAddOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADD_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/deconv_add_kernel.h"
namespace paddle_mobile {
namespace operators {
using std::string;
using std::vector;
class FusionDeconvAddMatcher : public framework::FusionOpMatcher {
public:
FusionDeconvAddMatcher() {
node_ = framework::Node(G_OP_TYPE_CONV_TRANSPOSE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DECONV_ADD; }
};
template <typename DeviceType, typename T>
class FusionDeconvAddOp : public framework::OperatorWithKernel<
DeviceType, FusionDeconvAddParam<DeviceType>,
operators::DeconvAddKernel<DeviceType, T>> {
public:
FusionDeconvAddOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDeconvAddParam<DeviceType>,
operators::DeconvAddKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
void InferShape() const {
auto input = this->param_.Input();
auto in_dims = input->dims();
auto filter = this->param_.Filter();
auto filter_dims = filter->dims();
std::vector<int> strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
std::vector<int> dilations = this->param_.Dilations();
int groups = this->param_.Groups();
PADDLE_MOBILE_ENFORCE(
in_dims.size() == 4 || in_dims.size() == 5,
"ConvTransposeOp intput should be 4-D or 5-D tensor.");
PADDLE_MOBILE_ENFORCE(
in_dims.size() == filter_dims.size(),
"ConvTransposeOp input dimension and filter dimension "
"should be the same.");
PADDLE_MOBILE_ENFORCE(
in_dims.size() - strides.size() == 2U,
"ConvTransposeOp input dimension and strides dimension should "
"be consistent.");
PADDLE_MOBILE_ENFORCE(paddings.size() == strides.size(),
"ConvTransposeOp paddings dimension and strides "
"dimension should be the same.");
PADDLE_MOBILE_ENFORCE(paddings.size() == dilations.size(),
"ConvTransposeOp paddings dimension and dilations "
"dimension should be the same.");
PADDLE_MOBILE_ENFORCE(
in_dims[1] == filter_dims[0],
"In ConvTransposeOp, The number of input channels should "
"be equal to the number of filter's channels.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[1] * groups});
for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
output_shape.push_back((in_dims[i + 2] - 1) * strides[i] -
2 * paddings[i] + filter_extent);
}
this->param_.Output()->Resize(framework::make_ddim(output_shape));
}
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif // FUSION_DECONV_ADD_OP
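The loop in InferShape above implements the usual transposed-convolution shape rule; a minimal standalone sketch with a worked value (helper name and numbers are illustrative, not from the source):

#include <cstdint>
// output = (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + 1
inline int64_t DeconvOutSize(int64_t in, int64_t k, int64_t stride,
                             int64_t pad, int64_t dilation) {
  int64_t filter_extent = dilation * (k - 1) + 1;
  return (in - 1) * stride - 2 * pad + filter_extent;
}
// DeconvOutSize(14, 4, 2, 1, 1) == 28: a stride-2, 4x4-kernel, pad-1 deconv
// doubles a 14x14 feature map.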
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADDRELU_OP
#include "operators/fusion_deconv_add_relu_op.h"
namespace paddle_mobile {
namespace operators {}
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_deconv_add_relu,
ops::FusionDeconvAddReluMatcher);
#ifdef PADDLE_MOBILE_CPU
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fusion_deconv_add_relu, ops::FusionDeconvAddReluOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADDRELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/deconv_add_relu_kernel.h"
namespace paddle_mobile {
namespace operators {
using std::string;
using std::vector;
class FusionDeconvAddReluMatcher : public framework::FusionOpMatcher {
public:
FusionDeconvAddReluMatcher() {
node_ = framework::Node(G_OP_TYPE_CONV_TRANSPOSE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}}, removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DECONV_ADD_RELU; }
};
template <typename DeviceType, typename T>
class FusionDeconvAddReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionDeconvAddReluParam<DeviceType>,
operators::DeconvAddReluKernel<DeviceType, T>> {
public:
FusionDeconvAddReluOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDeconvAddReluParam<DeviceType>,
operators::DeconvAddReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const {
auto input = this->param_.Input();
auto in_dims = input->dims();
auto filter = this->param_.Filter();
auto filter_dims = filter->dims();
std::vector<int> strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
std::vector<int> dilations = this->param_.Dilations();
int groups = this->param_.Groups();
PADDLE_MOBILE_ENFORCE(
in_dims.size() == 4 || in_dims.size() == 5,
"ConvTransposeOp intput should be 4-D or 5-D tensor.");
PADDLE_MOBILE_ENFORCE(
in_dims.size() == filter_dims.size(),
"ConvTransposeOp input dimension and filter dimension "
"should be the same.");
PADDLE_MOBILE_ENFORCE(
in_dims.size() - strides.size() == 2U,
"ConvTransposeOp input dimension and strides dimension should "
"be consistent.");
PADDLE_MOBILE_ENFORCE(paddings.size() == strides.size(),
"ConvTransposeOp paddings dimension and strides "
"dimension should be the same.");
PADDLE_MOBILE_ENFORCE(paddings.size() == dilations.size(),
"ConvTransposeOp paddings dimension and dilations "
"dimension should be the same.");
PADDLE_MOBILE_ENFORCE(
in_dims[1] == filter_dims[0],
"In ConvTransposeOp, The number of input channels should "
"be equal to the number of filter's channels.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[1] * groups});
for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
output_shape.push_back((in_dims[i + 2] - 1) * strides[i] -
2 * paddings[i] + filter_extent);
}
this->param_.Output()->Resize(framework::make_ddim(output_shape));
}
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif // FUSION_DECONV_ADD_RELU_OP
@@ -115,6 +115,7 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
template <typename P>
void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
param.Output()->mutable_data<float>();
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
...
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define BATCH_NORM
#define RELU
#include "conv_kernel.inc.cl"
@@ -20,7 +20,8 @@ __kernel void fetch(__private const int in_height,
__global float* out,
__private const int size_ch,
__private const int size_block,
__private const int size_batch,
__private const int C) {
const int in_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
@@ -35,9 +36,17 @@ __kernel void fetch(__private const int in_height,
const int index = in_n * size_batch + in_c * size_block + in_h * in_width + in_w;
out[index] = convert_float(in.x);
if(C - 4 * in_c>=2){
out[index + size_ch] = convert_float(in.y);
}
if(C - 4 * in_c>=3){
out[index + size_ch * 2] = convert_float(in.z);
}
if(C - 4 * in_c>=4){
out[index + size_ch * 3] = convert_float(in.w);
}
}
__kernel void fetch_2d(__private const int in_height,
...
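Note on the fetch kernel change: image channels are packed four per texel, so for the last channel block only C - 4 * in_c lanes hold real data, and the new guards keep the kernel from writing past the tensor when C is not a multiple of 4. The lane bound, restated as a host-side sketch (illustrative helper, not from the source):

#include <algorithm>
// Valid lanes in channel block in_c when C real channels are packed 4-per-texel.
int ValidLanes(int C, int in_c) { return std::min(4, C - 4 * in_c); }
// C = 6, in_c = 1 -> 2: only in.x and in.y may be stored for the last block.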
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void prior_box(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
__global float *box_width,
__global float *box_height,
__write_only image2d_t output_image,
__private const float step_width,
__private const float step_height,
__private const float offset,
__private const int img_width,
__private const int img_height,
__private const int num_priors,
__private const int C){
const int out_c = get_global_id(0);
const int out_nh = get_global_id(1);
const int out_n = out_nh/num_priors;
const int out_h = out_nh%num_priors;
if (out_c >= global_size_dim0 ||out_nh >= global_size_dim2) {
return;
}
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 output_pos;
output_pos.x = out_c * 4;
output_pos.y = out_nh;
float center_x0 = (offset + out_c * 4) * step_width;
float center_x1 = (offset + out_c * 4 + 1) * step_width;
float center_x2 = (offset + out_c * 4 + 2) * step_width;
float center_x3 = (offset + out_c * 4 + 3) * step_width;
float center_y = (out_n + offset) * step_height;
half4 output[4];
output[0].x = convert_half((center_x0 - box_width[out_h]) / img_width);
output[1].x = convert_half((center_y - box_height[out_h]) / img_height);
output[2].x = convert_half((center_x0 + box_width[out_h]) / img_width);
output[3].x = convert_half((center_y + box_height[out_h]) / img_height);
if(C - 4 * out_c>=2){
output[0].y = convert_half((center_x1 - box_width[out_h]) / img_width);
output[1].y = convert_half((center_y - box_height[out_h]) / img_height);
output[2].y = convert_half((center_x1 + box_width[out_h]) / img_width);
output[3].y = convert_half((center_y + box_height[out_h]) / img_height);
}else{
output[0].y = 0.0f;
output[1].y = 0.0f;
output[2].y = 0.0f;
output[3].y = 0.0f;
}
if(C - 4 * out_c>=3){
output[0].z = convert_half((center_x2 - box_width[out_h]) / img_width);
output[1].z = convert_half((center_y - box_height[out_h]) / img_height);
output[2].z = convert_half((center_x2 + box_width[out_h]) / img_width);
output[3].z = convert_half((center_y + box_height[out_h]) / img_height);
}else{
output[0].z = 0.0f;
output[1].z = 0.0f;
output[2].z = 0.0f;
output[3].z = 0.0f;
}
if(C - 4 * out_c>=4){
output[0].w = convert_half((center_x3 - box_width[out_h]) / img_width);
output[1].w = convert_half((center_y - box_height[out_h]) / img_height);
output[2].w = convert_half((center_x3 + box_width[out_h]) / img_width);
output[3].w = convert_half((center_y + box_height[out_h]) / img_height);
}else{
output[0].w = 0.0f;
output[1].w = 0.0f;
output[2].w = 0.0f;
output[3].w = 0.0f;
}
output[0] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[0]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
output[1] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[1]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
output[2] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[2]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
output[3] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[3]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
write_imageh(output_image, (int2)(output_pos.x + 1, output_pos.y), output[0]);
write_imageh(output_image, (int2)(output_pos.x + 2, output_pos.y), output[1]);
write_imageh(output_image, (int2)(output_pos.x + 3, output_pos.y), output[2]);
write_imageh(output_image, (int2)(output_pos.x + 4, output_pos.y), output[3]);
}
\ No newline at end of file
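The prior_box kernel uses the same 4-wide packing: work-item out_c fills prior columns 4*out_c .. 4*out_c+3 as the x/y/z/w lanes of four half4 values (xmin, ymin, xmax, ymax), and the C guards zero any lane past the real box count. The column-to-center math, as a host-side sketch (illustrative helpers, not from the source):

int Column(int out_c, int lane) { return 4 * out_c + lane; }
float CenterX(int out_c, int lane, float offset, float step_width) {
  // matches center_x0..center_x3 in the kernel above
  return (offset + Column(out_c, lane)) * step_width;
}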
@@ -68,10 +68,10 @@ void ConvAddKernel<GPU_CL, float>::Compute(
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
DLOG << "---yangfei30---";
DLOG << *param.Filter();
DLOG << param.Paddings();
auto biase = param.Bias()->GetCLImage();
param.Output()->InitEmptyImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue(),
param.Output()->dims());
auto output = param.Output()->GetCLImage();
int stride = param.Strides()[0];
int offset = param.Offset();
...
@@ -22,12 +22,185 @@ namespace operators {
template <>
bool ConvBNReluKernel<GPU_CL, float>::Init(
FusionConvBNReluParam<GPU_CL> *param) {
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
const framework::CLImage *mean = param->InputMean();
const framework::CLImage *variance = param->InputVariance();
const framework::CLImage *scale = param->InputScale();
const framework::CLImage *bias = param->InputBias();
const float epsilon = param->Epsilon();
const int C = mean->numel();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
float *new_scale_ptr = new float[C];
float *new_bias_ptr = new float[C];
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
framework::CLImage *new_scale = new framework::CLImage();
// for (int j = 0; j < C; ++j) {
// DLOG << " new scale - " << j << new_scale_ptr[j];
// }
//
// for (int j = 0; j < C; ++j) {
// DLOG << " new bias - " << j << new_bias_ptr[j];
// }
new_scale->SetTensorData(new_scale_ptr, variance->dims());
new_scale->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
// DLOG << " climage - y bias: " << *(param->Bias());
//
// DLOG << " climage - new scale: " << *new_scale;
framework::CLImage *new_bias = new framework::CLImage();
new_bias->SetTensorData(new_bias_ptr, variance->dims());
new_bias->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
// DLOG << " climage - new bias: " << *new_bias;
//
// DLOG << " climage - filter: " << *(param->Filter());
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
delete[](new_scale_ptr);
delete[](new_bias_ptr);
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
if (param->Filter()->dims()[2] == 1 && param->Filter()->dims()[3] == 1) {
param->Filter()->InitNImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_1x1", "conv_bn_relu_kernel.cl");
DLOG << " conv bn relu conv 1x1";
} else if (param->Filter()->dims()[1] == 1 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == 3) {
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("depth_conv_3x3", "conv_bn_relu_kernel.cl");
DLOG << " conv bn relu depth_conv_3x3";
} else if (param->Filter()->dims()[2] == 3 &&
param->Filter()->dims()[3] == 3) {
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_3x3", "conv_bn_relu_kernel.cl");
DLOG << " conv bn relu conv_3x3";
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
}
return true;
}
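The Init above folds batch norm into one per-channel multiply-add so the fused kernel can apply relu(new_scale * conv(x) + new_bias) directly: with inv_std = 1/sqrt(variance + epsilon), scale*(conv(x) - mean)*inv_std + bias equals (scale*inv_std)*conv(x) + (bias - mean*scale*inv_std). A scalar check of the identity (illustrative values only):

#include <cmath>
float conv_out = 2.f, mean = 0.5f, var = 4.f, scale = 3.f, bias = 1.f, eps = 0.f;
float inv_std = 1.f / std::sqrt(var + eps);             // 0.5
float bn = scale * (conv_out - mean) * inv_std + bias;  // 3.25
float new_scale = scale * inv_std;                      // 1.5
float new_bias = bias - mean * new_scale;               // 0.25
float folded = new_scale * conv_out + new_bias;         // 3.25 == bn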
template <>
void ConvBNReluKernel<GPU_CL, float>::Compute(
const FusionConvBNReluParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
auto new_scale = param.NewScale()->GetCLImage();
auto new_bias = param.NewBias()->GetCLImage();
auto output = param.Output()->GetCLImage();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
param.Input()->Converter())
->GetCBlock();
int dilation = param.Dilations()[0];
int input_width = param.Input()->dims()[3];
int input_height = param.Input()->dims()[2];
int output_width = param.Output()->dims()[3];
int output_height = param.Output()->dims()[2];
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &new_scale);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_bias);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &output);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(int), &stride);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &input_c);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &dilation);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 13, sizeof(int), &input_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 14, sizeof(int), &output_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 15, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
template class ConvBNReluKernel<GPU_CL, float>;
} // namespace operators
...
@@ -22,12 +22,151 @@ namespace operators {
template <>
bool DWConvBNReluKernel<GPU_CL, float>::Init(
FusionDWConvBNReluParam<GPU_CL> *param) {
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
const framework::CLImage *mean = param->InputMean();
const framework::CLImage *variance = param->InputVariance();
const framework::CLImage *scale = param->InputScale();
const framework::CLImage *bias = param->InputBias();
const float epsilon = param->Epsilon();
const int C = mean->numel();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
float *new_scale_ptr = new float[C];
float *new_bias_ptr = new float[C];
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
framework::CLImage *new_scale = new framework::CLImage();
new_scale->SetTensorData(new_scale_ptr, variance->dims());
new_scale->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
framework::CLImage *new_bias = new framework::CLImage();
new_bias->SetTensorData(new_bias_ptr, variance->dims());
new_bias->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
delete[](new_scale_ptr);
delete[](new_bias_ptr);
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("depth_conv_3x3", "conv_bn_relu_kernel.cl");
DLOG << " conv bn relu depth_conv_3x3";
return true;
}
template <>
void DWConvBNReluKernel<GPU_CL, float>::Compute(
const FusionDWConvBNReluParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
auto new_scale = param.NewScale()->GetCLImage();
auto new_bias = param.NewBias()->GetCLImage();
auto output = param.Output()->GetCLImage();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
param.Input()->Converter())
->GetCBlock();
int dilation = param.Dilations()[0];
int input_width = param.Input()->dims()[3];
int input_height = param.Input()->dims()[2];
int output_width = param.Output()->dims()[3];
int output_height = param.Output()->dims()[2];
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &new_scale);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_bias);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &output);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(int), &stride);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &input_c);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &dilation);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 13, sizeof(int), &input_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 14, sizeof(int), &output_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 15, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
template class DWConvBNReluKernel<GPU_CL, float>;
} // namespace operators
...
@@ -28,6 +28,8 @@ template <>
void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
cl_int status;
param.Out()->InitEmptyImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue(), param.Out()->dims());
auto output = param.Out();
const Tensor *input = param.InputX();
// DLOG << *input;
...
@@ -27,8 +27,6 @@ bool FetchKernel<GPU_CL, float>::Init(FetchParam<GPU_CL> *param) {
} else {
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
}
auto *out = param->Out();
out->mutable_data<float>();
return true;
}
@@ -39,7 +37,7 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
auto input = param.InputX()->GetCLImage();
auto *out = param.Out();
out->mutable_data<float>();
const auto &dim = param.InputX()->dims();
size_t new_dims[] = {1, 1, 1, 1};
@@ -70,9 +68,11 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
int size_ch = in_height * in_width;
int size_block = size_ch * 4;
int size_batch = size_ch * C;
int out_c = new_dims[1];
clSetKernelArg(kernel, 4, sizeof(int), &size_ch);
clSetKernelArg(kernel, 5, sizeof(int), &size_block);
clSetKernelArg(kernel, 6, sizeof(int), &size_batch);
clSetKernelArg(kernel, 7, sizeof(int), &out_c);
}
// cl_event wait_event = param.InpdutX()->GetClEvent();
@@ -93,6 +93,8 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
// << "ms" << std::endl;
memcpy(out->data<float>(), out_cl_tensor.Data<float>(), out->memory_size());
DLOG << *param.InputX();
DLOG << *out;
}
template class FetchKernel<GPU_CL, float>;
...
@@ -15,18 +15,165 @@ limitations under the License. */
#ifdef PRIORBOX_OP
#include "operators/kernel/prior_box_kernel.h"
#include "framework/cl/cl_tensor.h"
namespace paddle_mobile {
namespace operators {
template <>
bool PriorBoxKernel<GPU_CL, float>::Init(PriorBoxParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("prior_box", "prior_box_kernel.cl");
return true;
}
template <>
void PriorBoxKernel<GPU_CL, float>::Compute(
const PriorBoxParam<GPU_CL> &param) {
const auto *input_ = param.Input();
const auto &input_dims = input_->dims();
const auto &input_image_dims = param.InputImage()->dims();
const auto &min_sizes = param.MinSizes();
const auto &max_sizes = param.MaxSizes();
const auto &variances = param.Variances();
const auto &input_aspect_ratio = param.AspectRatios();
const bool &flip = param.Flip();
const bool &clip = param.Clip();
const float &step_w = param.StepW();
const float &step_h = param.StepH();
const float &offset = param.Offset();
const int C = param.OutputBoxes()->dims()[1];
auto output_boxes = param.OutputBoxes()->GetCLImage();
auto output_variances = param.OutputVariances()->GetCLImage();
std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
auto img_width = input_image_dims[3];
auto img_height = input_image_dims[2];
auto feature_width = input_dims[3];
auto feature_height = input_dims[2];
float step_width, step_height;
/// 300 / 19
if (step_w == 0 || step_h == 0) {
step_width = static_cast<float>(img_width) / feature_width;
step_height = static_cast<float>(img_height) / feature_height;
} else {
step_width = step_w;
step_height = step_h;
}
int num_priors = aspect_ratios.size() * min_sizes.size();
if (!max_sizes.empty()) {
num_priors += max_sizes.size();
}
float *box_width = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * num_priors));
float *box_height = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * num_priors));
int idx = 0;
for (size_t s = 0; s < min_sizes.size(); ++s) {
auto min_size = min_sizes[s];
if (param.MinMaxAspectRatiosOrder()) {
box_width[idx] = box_height[idx] = min_size / 2.;
idx++;
if (max_sizes.size() > 0) {
auto max_size = max_sizes[s];
box_width[idx] = box_height[idx] = sqrt(min_size * max_size) / 2.;
idx++;
}
for (float ar : aspect_ratios) {
if (fabs(ar - 1.) < 1e-6) {
continue;
}
box_width[idx] = min_size * sqrt(ar) / 2.;
box_height[idx] = min_size / sqrt(ar) / 2.;
idx++;
}
} else {
for (float ar : aspect_ratios) {
box_width[idx] = min_size * sqrt(ar) / 2.;
box_height[idx] = min_size / sqrt(ar) / 2.;
idx++;
}
if (!max_sizes.empty()) {
auto max_size = max_sizes[s];
box_width[idx] = box_height[idx] = sqrt(min_size * max_size) / 2.;
idx++;
}
}
}
cl_int status;
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size =
this->cl_helper_.DefaultWorkSize(*param.OutputBoxes());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
std::vector<int64_t> box_shape({1, 1, 1, num_priors});
framework::DDim ddim = framework::make_ddim(box_shape);
framework::CLTensor box_width_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
box_width_cl_tensor.Resize(ddim);
cl_mem box_width_Buffer =
box_width_cl_tensor.mutable_with_data<float>(box_width);
framework::CLTensor box_height_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
box_height_cl_tensor.Resize(ddim);
cl_mem box_height_Buffer =
box_height_cl_tensor.mutable_with_data<float>(box_height);
DLOG << "c_block:" << c_block;
DLOG << "w:" << w;
DLOG << "nh:" << nh;
DLOG << "step_width:" << step_width;
DLOG << "step_height:" << step_height;
DLOG << "offset:" << offset;
DLOG << "img_width:" << img_width;
DLOG << "img_height:" << img_height;
DLOG << "num_priors:" << num_priors;
DLOG << "C:" << C;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &box_width_Buffer);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &box_height_Buffer);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &output_boxes);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(float), &step_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(float), &step_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(float), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &img_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &img_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &num_priors);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &C);
CL_CHECK_ERRORS(status);
size_t global_work_size[2] = {c_block, nh};
status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2,
NULL, global_work_size, NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
paddle_mobile::memory::Free(box_width);
paddle_mobile::memory::Free(box_height);
}
template class PriorBoxKernel<GPU_CL, float>;
} // namespace operators
...
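As a sanity check on the sizes above: num_priors follows the SSD convention, aspect_ratios.size() * min_sizes.size() plus one extra square box per max_size. For example, min_sizes = {60}, max_sizes = {111}, and aspect ratios {1, 2} expanded with flip to {1, 2, 0.5} give num_priors = 3 * 1 + 1 = 4, so box_width and box_height each hold four half-extents (all values are divided by 2 when filled in).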
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADD_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class DeconvAddKernel
: public OpKernelBase<DeviceType, FusionDeconvAddParam<DeviceType>> {
public:
void Compute(const FusionDeconvAddParam<DeviceType> &param);
bool Init(FusionDeconvAddParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADDRELU_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class DeconvAddReluKernel
: public OpKernelBase<DeviceType, FusionDeconvAddReluParam<DeviceType>> {
public:
void Compute(const FusionDeconvAddReluParam<DeviceType> &param);
bool Init(FusionDeconvAddReluParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADD_OP
#include "operators/kernel/conv_add_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddKernel<FPGA, float>::Init(FusionConvAddParam<FPGA> *param) {
bool relu_enabled = false;
auto input = const_cast<Tensor *>(param->Input());
const Tensor *bias = param->Bias();
auto bias_ptr = bias->data<float>();
auto filter = const_cast<Tensor *>(param->Filter());
auto out = param->Output();
PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0],
"Output channel should be equal to bias number");
int channel = out->dims()[1];
auto bs_ptr =
(float *)fpga::fpga_malloc(2 * channel * sizeof(float)); // NOLINT
for (int i = 0; i < channel; i++) {
bs_ptr[i + channel] = 1;
bs_ptr[i] = bias_ptr[i];
}
fpga::format_conv_data(filter, out, bs_ptr, param->Groups());
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0],
param->Strides()[1], param->Paddings()[0],
param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(conv_arg);
return true;
}
template <>
void ConvAddKernel<FPGA, float>::Compute(
const FusionConvAddParam<FPGA> &param) {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
} // namespace operators
} // namespace paddle_mobile
#endif
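Note on the bs_ptr layout above: the first `channel` floats carry the bias values and the second `channel` floats a per-channel scale fixed at 1, i.e. out[c] = 1 * conv[c] + bias[c]. Presumably the BN-fused FPGA variants fill the scale half with folded values instead; that reading is inferred from the layout, not stated in this commit.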
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADD_OP
#include "operators/kernel/deconv_add_kernel.h"
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <>
bool DeconvAddKernel<FPGA, float>::Init(FusionDeconvAddParam<FPGA> *param) {
return true;
}
template <>
void DeconvAddKernel<FPGA, float>::Compute(
const FusionDeconvAddParam<FPGA> &param) {}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DECONVADDRELU_OP
#include "operators/kernel/deconv_add_relu_kernel.h"
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <>
bool DeconvAddReluKernel<FPGA, float>::Init(
FusionDeconvAddReluParam<FPGA> *param) {
return true;
}
template <>
void DeconvAddReluKernel<FPGA, float>::Compute(
const FusionDeconvAddReluParam<FPGA> &param) {}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SPLIT_OP
#include "operators/kernel/split_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool SplitKernel<FPGA, float>::Init(SplitParam<FPGA>* param) {
return true;
}
template <>
void SplitKernel<FPGA, float>::Compute(const SplitParam<FPGA>& param) {}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TRANSPOSE2_OP
#include "operators/kernel/transpose2_kernel.h"
#include "operators/kernel/central-arm-func/transpose2_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool Transpose2Kernel<FPGA, float>::Init(Transpose2Param<FPGA> *param) {
return true;
}
template <>
void Transpose2Kernel<FPGA, float>::Compute(
const Transpose2Param<FPGA> &param) {
// Transpose2Compute<float>(param);
}
} // namespace operators
} // namespace paddle_mobile
#endif
@@ -2234,7 +2234,10 @@ class ConvTransposeParam : public OpParam {
const Scope &scope) {
filter_ = FilterFrom<GType>(inputs, scope);
input_ = InputFrom<GType>(inputs, scope);
// output_ = OutputFrom<GType>(outputs, scope);
if (outputs.count("Output")) {
output_ = OpParam::OutputFrom<GType>(outputs, scope);
}
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
@@ -2275,6 +2278,38 @@
#endif
};
#endif
#ifdef FUSION_DECONVADD_OP
template <typename Dtype>
class FusionDeconvAddParam : public ConvTransposeParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDeconvAddParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: ConvTransposeParam<Dtype>(inputs, outputs, attrs, scope) {
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
output_ = OpParam::OutFrom<GType>(outputs, scope);
}
RType *Bias() const { return bias_; }
const int &Axis() const { return axis_; }
RType *Output() const { return output_; }
protected:
RType *bias_;
int axis_;
RType *output_;
};
#endif
#ifdef FUSION_DECONVADDRELU_OP
template <typename Dtype>
using FusionDeconvAddReluParam = FusionDeconvAddParam<Dtype>;
#endif
#ifdef FUSION_DECONVRELU_OP
template <typename Dtype>
...
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SPLIT_OP
#include "operators/split_op.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
@@ -83,5 +83,8 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(split, ops::SplitOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(split, ops::SplitOp);
#endif
#endif // SPLIT_OP
@@ -29,7 +29,7 @@ void TanhOp<DeviceType, T>::InferShape() const {
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(tanh, ops::TanhOp);
#endif
#endif
@@ -60,5 +60,8 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(transpose2, ops::Transpose2Op);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(transpose2, ops::Transpose2Op);
#endif
#endif // TRANSPOSE_OP
@@ -66,6 +66,10 @@ list(FIND NET "FPGA_NET_V1" CON)
if (CON GREATER -1)
ADD_EXECUTABLE(test-resnet50 fpga/test_resnet50.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-resnet50 paddle-mobile)
ADD_EXECUTABLE(test-densebox net/test_densebox_combine.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-densebox paddle-mobile)
set(FOUND_MATCH ON)
endif ()
@@ -76,6 +80,10 @@ if (CON GREATER -1)
ADD_EXECUTABLE(test-pe fpga/test_pe.cpp)
target_link_libraries(test-pe paddle-mobile)
ADD_EXECUTABLE(test-densebox net/test_densebox_combine.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-densebox paddle-mobile)
set(FOUND_MATCH ON)
endif ()
...
@@ -127,6 +127,8 @@ endif()
list(FIND NET "FPGA_NET_V2" CON)
if (CON GREATER -1)
message("FPGA_NET_V2 enabled")
set(FEED_OP ON)
set(FUSION_CONVADDRELU_OP ON)
set(FUSION_ELEMENTWISEADDRELU_OP ON)
set(FUSION_FC_OP ON)
set(POOL_OP ON)
@@ -135,9 +137,16 @@ if (CON GREATER -1)
set(FUSION_CONVBN_OP ON)
set(CONV_TRANSPOSE_OP ON)
set(FUSION_DECONVRELU_OP ON)
#set(SLICE_OP ON)
set(TANH_OP ON)
set(ELEMENTWISEADD_OP ON)
set(TRANSPOSE2_OP ON)
set(FUSION_CONVADD_OP ON)
set(SPLIT_OP ON)
set(FUSION_DECONVADD_OP ON)
set(FUSION_DECONVADDRELU_OP ON)
set(FOUND_MATCH ON)
endif()
@@ -452,4 +461,10 @@ if (TANH_OP)
endif()
if (FUSION_DECONVRELU_OP)
add_definitions(-DFUSION_DECONVRELU_OP)
endif()
if (FUSION_DECONVADD_OP)
add_definitions(-DFUSION_DECONVADD_OP)
endif()
if (FUSION_DECONVADDRELU_OP)
add_definitions(-DFUSION_DECONVADDRELU_OP)
endif()
\ No newline at end of file
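These per-op switches hang off the NET selector that the list(FIND NET ...) checks read, so an FPGA_NET_V2 build would typically be configured with something like cmake -DNET="FPGA_NET_V2" .. (assumed invocation; the exact flags depend on the project's build scripts).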