提交 7fbea0e2 编写于 作者: Y yangfei

add some function

上级 49bce554
...@@ -14,110 +14,107 @@ limitations under the License. */ ...@@ -14,110 +14,107 @@ limitations under the License. */
#include "cl_image.h" #include "cl_image.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
void CLImageToTensor(CLImage *cl_image, Tensor *tensor,cl_command_queue commandQueue){ void CLImageToTensor(CLImage *cl_image, Tensor *tensor,
cl_command_queue commandQueue) {
DDim ddim = cl_image->dims(); DDim ddim = cl_image->dims();
size_t N,C,H,W; size_t N, C, H, W;
if(ddim.size()==4){ if (ddim.size() == 4) {
N = ddim[0]; N = ddim[0];
if(N<0){ if (N < 0) {
N = 1; N = 1;
} }
C = ddim[1]; C = ddim[1];
H = ddim[2]; H = ddim[2];
W = ddim[3]; W = ddim[3];
}else if(ddim.size()==1){ } else if (ddim.size() == 1) {
N = 1; N = 1;
C = ddim[0]; C = ddim[0];
H = 1; H = 1;
W = 1; W = 1;
} }
size_t width = W * ((C + 3) / 4); size_t width = W * ((C + 3) / 4);
size_t height = H * N; size_t height = H * N;
float *p = tensor->data<float>(); float *p = tensor->data<float>();
half imageData[width * height * 4]; half imageData[width * height * 4];
cl_int err; cl_int err;
cl_mem image = cl_image->GetCLImage(); cl_mem image = cl_image->GetCLImage();
size_t origin[3] = {0,0,0}; size_t origin[3] = {0, 0, 0};
size_t region[3] = {width,height,1}; size_t region[3] = {width, height, 1};
err = clEnqueueReadImage(commandQueue,image,CL_TRUE,origin,region,0,0,imageData,0,NULL,NULL); err = clEnqueueReadImage(commandQueue, image, CL_TRUE, origin, region, 0, 0,
size_t i0 = 0; imageData, 0, NULL, NULL);
for (int n = 0; n < N; n++) { size_t i0 = 0;
for (int c = 0; c < C; c++) { for (int n = 0; n < N; n++) {
size_t i1 = i0; for (int c = 0; c < C; c++) {
for (int h = 0; h < H; h++) { size_t i1 = i0;
size_t i2 = (i1<<2) + c % 4; for (int h = 0; h < H; h++) {
for (int w = 0; w < W; w++) { size_t i2 = (i1 << 2) + c % 4;
*p = half2float(imageData[i2]); for (int w = 0; w < W; w++) {
i2 += 4; *p = half2float(imageData[i2]);
p++; i2 += 4;
} p++;
i1 += width;
}
}
i0 += width * H;
}
if (err != CL_SUCCESS) {
// TODO: error handling
}
} }
void TensorToCLImage(const Tensor *tensor, CLImage *cl_image,cl_command_queue commandQueue){ i1 += width;
}
DDim ddim = cl_image->dims(); }
size_t N,C,H,W; i0 += width * H;
if(ddim.size()==4){ }
N = ddim[0];
if(N<0){
N = 1;
}
C = ddim[1];
H = ddim[2];
W = ddim[3];
}else if(ddim.size()==1){
N = 1;
C = ddim[0];
H = 1;
W = 1;
}
size_t width = W * ((C + 3) / 4);
size_t height = H * N;
const float *p = tensor->data<float>();
half imageData[width * height * 4];
cl_mem image = cl_image->GetCLImage();
size_t origin[3] = {0,0,0};
size_t region[3] = {width,height,1};
cl_int err;
err = clEnqueueReadImage(commandQueue,image,CL_TRUE,origin,region,0,0,imageData,0,NULL,NULL);
if (err != CL_SUCCESS) {
// TODO: error handling
}
size_t i0 = 0;
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; c++) {
size_t i1 = i0;
for (int h = 0; h < H; h++) {
size_t i2 = (i1<<2) + c % 4;
for (int w = 0; w < W; w++) {
imageData[i2] = float2half(*p);
i2 += 4;
p++;
}
i1 += width;
}
}
i0 += width * H;
}
if (err != CL_SUCCESS) {
// TODO: error handling
}
}
void TensorToCLImage(const Tensor *tensor, CLImage *cl_image,
cl_command_queue commandQueue) {
DDim ddim = cl_image->dims();
size_t N, C, H, W;
if (ddim.size() == 4) {
N = ddim[0];
if (N < 0) {
N = 1;
}
C = ddim[1];
H = ddim[2];
W = ddim[3];
} else if (ddim.size() == 1) {
N = 1;
C = ddim[0];
H = 1;
W = 1;
}
size_t width = W * ((C + 3) / 4);
size_t height = H * N;
const float *p = tensor->data<float>();
half imageData[width * height * 4];
cl_mem image = cl_image->GetCLImage();
size_t origin[3] = {0, 0, 0};
size_t region[3] = {width, height, 1};
cl_int err;
err = clEnqueueReadImage(commandQueue, image, CL_TRUE, origin, region, 0, 0,
imageData, 0, NULL, NULL);
if (err != CL_SUCCESS) {
// TODO: error handling
}
size_t i0 = 0;
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; c++) {
size_t i1 = i0;
for (int h = 0; h < H; h++) {
size_t i2 = (i1 << 2) + c % 4;
for (int w = 0; w < W; w++) {
imageData[i2] = float2half(*p);
i2 += 4;
p++;
} }
i1 += width;
}
} }
i0 += width * H;
}
} }
} // namespace framework
} // namespace paddle_mobile
...@@ -30,6 +30,20 @@ class CLImage { ...@@ -30,6 +30,20 @@ class CLImage {
void Init(cl_context context, float *tensorInput, DDim ddim) { void Init(cl_context context, float *tensorInput, DDim ddim) {
tensor_dims_ = ddim; tensor_dims_ = ddim;
if (tensorInput) {
tensor_input_ = tensorInput;
} else {
int numel = 1;
for (int i = 0; i < ddim.size(); i++) {
numel *= ddim[i];
}
tensor_input_ = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * numel));
for (int i = 0; i < numel; i++) {
tensor_input_[i] = 0;
}
}
cl_image_format cf = {.image_channel_order = CL_RGBA, cl_image_format cf = {.image_channel_order = CL_RGBA,
.image_channel_data_type = CL_HALF_FLOAT}; .image_channel_data_type = CL_HALF_FLOAT};
// NCHW -> [W * (C+3)/4, H * N] // NCHW -> [W * (C+3)/4, H * N]
...@@ -65,9 +79,9 @@ class CLImage { ...@@ -65,9 +79,9 @@ class CLImage {
std::unique_ptr<half_t[]> imageData{}; std::unique_ptr<half_t[]> imageData{};
int count = 0; int count = 0;
if (tensorInput != nullptr) { imageData.reset(new half_t[width * height * 4]);
imageData.reset(new half_t[width * height * 4]); if (tensor_input_ != nullptr) {
float *p = tensorInput; float *p = tensor_input_;
size_t i0 = 0; size_t i0 = 0;
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
for (int c = 0; c < C; c++) { for (int c = 0; c < C; c++) {
...@@ -75,11 +89,13 @@ class CLImage { ...@@ -75,11 +89,13 @@ class CLImage {
for (int h = 0; h < H; h++) { for (int h = 0; h < H; h++) {
size_t i2 = (i1 << 2) + c % 4; size_t i2 = (i1 << 2) + c % 4;
for (int w = 0; w < W; w++) { for (int w = 0; w < W; w++) {
if (i2 >= width * height * 4) { // if (i2 >= width * height * 4) {
printf("%d > %d ----> %d, %d, %d, %d --- %d, %d, %d\n", i2, // printf("%d > %d ----> %d, %d, %d, %d --- %d, %d,
width * height * 4, n, c, h, w, i0, i1, i2); // %d\n", i2,
} // width * height * 4, n, c, h, w, i0, i1,
assert(i2 < width * height * 4); // i2);
// }
// assert(i2 < width * height * 4);
imageData[i2] = float2half(*p); imageData[i2] = float2half(*p);
i2 += 4; i2 += 4;
...@@ -153,9 +169,11 @@ class CLImage { ...@@ -153,9 +169,11 @@ class CLImage {
cl_context context_; cl_context context_;
}; };
void TensorToCLImage(Tensor *tensor, CLImage *image); void TensorToCLImage(Tensor *tensor, CLImage *image,
cl_command_queue commandQueue);
void CLImageToTensor(CLImage *image, Tensor *tensor); void CLImageToTensor(CLImage *image, Tensor *tensor,
cl_command_queue commandQueue);
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -56,7 +56,8 @@ class CLScope { ...@@ -56,7 +56,8 @@ class CLScope {
auto program = CLEngine::Instance()->CreateProgramWith( auto program = CLEngine::Instance()->CreateProgramWith(
context_.get(), "./cl_kernel/" + file_name); context_.get(), "./cl_kernel/" + file_name);
status_ = clBuildProgram(program.get(), 0, 0, "-cl-fast-relaxed-math", 0, 0); status_ =
clBuildProgram(program.get(), 0, 0, "-cl-fast-relaxed-math", 0, 0);
CL_CHECK_ERRORS(status_); CL_CHECK_ERRORS(status_);
programs_[file_name] = std::move(program); programs_[file_name] = std::move(program);
......
...@@ -931,7 +931,7 @@ void Executor<GPU_CL, Precision::FP32>::InitMemory() { ...@@ -931,7 +931,7 @@ void Executor<GPU_CL, Precision::FP32>::InitMemory() {
cl_image->Init(context, tensorInput, ddim); cl_image->Init(context, tensorInput, ddim);
delete origin_data; delete origin_data;
paddle_mobile::memory::Free(tensorInput); // paddle_mobile::memory::Free(tensorInput);
} else { } else {
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) { if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
auto cl_image = var->template GetMutable<framework::CLImage>(); auto cl_image = var->template GetMutable<framework::CLImage>();
......
...@@ -72,13 +72,16 @@ void OperatorBase<Dtype>::Run() { ...@@ -72,13 +72,16 @@ void OperatorBase<Dtype>::Run() {
if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor; if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
} else { } else {
CLImage *cl_image = vari->template GetMutable<framework::CLImage>(); CLImage *cl_image = vari->template GetMutable<framework::CLImage>();
// cl_command_queue commandQueue = // cl_command_queue commandQueue =
// scope_->GetCLScpoe()->CommandQueue(); Tensor *tmp ; // scope_->GetCLScpoe()->CommandQueue(); Tensor
// CLImageToTensor(cl_image,tmp,commandQueue); // *tmp ;
// tmp->Resize(cl_image->dims()); // CLImageToTensor(cl_image,tmp,commandQueue);
// tmp->Resize(cl_image->dims());
const float *input = cl_image->data<float>();
if (cl_image) { if (cl_image) {
// DLOG<<type_<<" input- "<<key<<"="<<*tmp;
DLOG << type_ << " input- " << key << "=" << cl_image->dims(); DLOG << type_ << " input- " << key << "=" << cl_image->dims();
// if(input)
// DLOG<<type_<<" input- "<<key<<"="<<*input;
} }
} }
...@@ -95,15 +98,24 @@ void OperatorBase<Dtype>::Run() { ...@@ -95,15 +98,24 @@ void OperatorBase<Dtype>::Run() {
auto vari = scope_->FindVar(var_vec_out[i]); auto vari = scope_->FindVar(var_vec_out[i]);
if (vari->IsInitialized()) { if (vari->IsInitialized()) {
#ifdef PADDLE_MOBILE_CL #ifdef PADDLE_MOBILE_CL
CLImage *cl_image = vari->template GetMutable<framework::CLImage>(); if (type_ == "fetch") {
// cl_command_queue commandQueue = Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
// scope_->GetCLScpoe()->CommandQueue(); Tensor *tmp ; if (tensor)
// CLImageToTensor(cl_image,tmp,commandQueue); DLOG << type_ << " output- " << key << "=" << tensor->dims();
// tmp->Resize(cl_image->dims()); } else {
if (cl_image) { CLImage *cl_image = vari->template GetMutable<framework::CLImage>();
// DLOG<<type_<<" output- "<<key<<"="<<*tmp; // cl_command_queue commandQueue =
DLOG << type_ << " output- " << key << "=" << cl_image->dims(); // scope_->GetCLScpoe()->CommandQueue(); Tensor *tmp ;
// CLImageToTensor(cl_image,tmp,commandQueue);
// tmp->Resize(cl_image->dims());
if (cl_image) {
const float *output = cl_image->data<float>();
DLOG << type_ << " output- " << key << "=" << cl_image->dims();
// if(output)
// DLOG<<type_<<" output- "<<key<<"="<<*output;
}
} }
#else #else
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>(); Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
if (tensor) DLOG << type_ << " output- " << key << "=" << *tensor; if (tensor) DLOG << type_ << " output- " << key << "=" << *tensor;
......
...@@ -98,8 +98,8 @@ class FeedOp : public framework::OperatorBase<DeviceType> { ...@@ -98,8 +98,8 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
void Init() {} void Init() {}
void RunImpl() { void RunImpl() {
param_.Out()->ShareDataWith(*param_.InputX()); param_.Out()->ShareDataWith(*param_.InputX());
param_.Out()->set_lod(param_.InputX()->lod()); param_.Out()->set_lod(param_.InputX()->lod());
} }
protected: protected:
......
...@@ -18,9 +18,10 @@ limitations under the License. */ ...@@ -18,9 +18,10 @@ limitations under the License. */
inline hafl4 activation(half4 in inline hafl4 activation(half4 in
#ifdef PRELU #ifdef PRELU
,half4 prelu_alpha ,
half4 prelu_alpha
#endif #endif
) { ) {
half4 output; half4 output;
#ifdef PRELU #ifdef PRELU
output = select(prelu_alpha * in, in, in >= (half4)0.0); output = select(prelu_alpha * in, in, in >= (half4)0.0);
...@@ -31,4 +32,3 @@ inline hafl4 activation(half4 in ...@@ -31,4 +32,3 @@ inline hafl4 activation(half4 in
#endif #endif
return output; return output;
} }
...@@ -24,9 +24,9 @@ template <> ...@@ -24,9 +24,9 @@ template <>
bool DepthwiseConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) { bool DepthwiseConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) {
DLOG << " depthwise conv kernel init begin "; DLOG << " depthwise conv kernel init begin ";
PADDLE_MOBILE_ENFORCE( PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] && param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1], param->Paddings()[0] == param->Paddings()[1],
"need equal"); "need equal");
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 - int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]); static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset); param->SetOffset(offset);
...@@ -36,7 +36,8 @@ bool DepthwiseConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) { ...@@ -36,7 +36,8 @@ bool DepthwiseConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) {
} }
template <> template <>
void DepthwiseConvKernel<GPU_CL, float>::Compute(const ConvParam<GPU_CL> &param) { void DepthwiseConvKernel<GPU_CL, float>::Compute(
const ConvParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0); auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output()); auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0]; int c_block = default_work_size[0];
...@@ -78,4 +79,4 @@ template class DepthwiseConvKernel<GPU_CL, float>; ...@@ -78,4 +79,4 @@ template class DepthwiseConvKernel<GPU_CL, float>;
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
#endif #endif
\ No newline at end of file
...@@ -12,42 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,42 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "common/log.h"
#include "operators/kernel/feed_kernel.h" #include "operators/kernel/feed_kernel.h"
#include "common/log.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool FeedKernel<GPU_CL, float>::Init(FeedParam<GPU_CL> *param) { bool FeedKernel<GPU_CL, float>::Init(FeedParam<GPU_CL> *param) {
DLOG<<"Init feed"; DLOG << "Init feed";
this->cl_helper_.AddKernel("feed", "feed_kernel.cl"); this->cl_helper_.AddKernel("feed", "feed_kernel.cl");
return true; return true;
} }
template <> template <>
void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) { void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
DLOG<<"feed_kernel"; cl_int status;
auto kernel = this->cl_helper_.KernelAt(0); auto output = param.Out();
cl_int status; const Tensor *input = param.InputX();
auto output = param.Out(); const float *input_data = nullptr;
auto input = param.InputX(); input_data = input->data<float>();
const float *input_data = input->data<float>();
cl_mem cl_image = output->GetCLImage(); cl_mem cl_image = output->GetCLImage();
int height = output->dims()[2]; int height = output->dims()[2];
int width = output->dims()[3]; int width = output->dims()[3];
status = clSetKernelArg(kernel,0, sizeof(cl_mem),&input_data); DLOG << output->dims();
status = clSetKernelArg(kernel,0, sizeof(cl_mem),&cl_image); status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_data);
status = clSetKernelArg(kernel,0, sizeof(cl_mem),&width); status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_image);
status = clSetKernelArg(kernel,0, sizeof(cl_mem),&height); status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &width);
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &height);
size_t global_work_size[2] = {height,width};
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, NULL, global_work_size, NULL, 0, NULL, NULL); size_t global_work_size[2] = {height, width};
} clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, NULL,
global_work_size, NULL, 0, NULL, NULL);
template class FeedKernel<GPU_CL, float>; }
} // namespace operators template class FeedKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -18,15 +18,15 @@ limitations under the License. */ ...@@ -18,15 +18,15 @@ limitations under the License. */
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
using namespace framework; using namespace framework;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FeedKernel class FeedKernel
: public framework::OpKernelBase<DeviceType, FeedParam<DeviceType>>{ : public framework::OpKernelBase<DeviceType, FeedParam<DeviceType>> {
public: public:
void Compute(const FeedParam<DeviceType> &param); void Compute(const FeedParam<DeviceType> &param);
bool Init(FeedParam<DeviceType> *param); bool Init(FeedParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -936,14 +936,14 @@ class FetchParam : public OpParam { ...@@ -936,14 +936,14 @@ class FetchParam : public OpParam {
FetchParam(const VariableNameMap &inputs, const VariableNameMap &outputs, FetchParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope); input_x_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope); out_ = OutFrom<LoDTensor>(outputs, scope);
} }
const RType *InputX() const { return input_x_; } const RType *InputX() const { return input_x_; }
RType *Out() const { return out_; } Tensor *Out() const { return out_; }
private: private:
RType *input_x_; RType *input_x_;
RType *out_; Tensor *out_;
}; };
#ifdef TRANSPOSE_OP #ifdef TRANSPOSE_OP
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册