未验证 提交 549ebc0a 编写于 作者: Y Yanzhan Yang 提交者: GitHub

restore fetch utility for opencl (#1770)

上级 ff4fb9b2
...@@ -460,6 +460,20 @@ std::shared_ptr<LoDTensor> Executor<Device, T>::GetOutput( ...@@ -460,6 +460,20 @@ std::shared_ptr<LoDTensor> Executor<Device, T>::GetOutput(
} }
} }
#ifdef PADDLE_MOBILE_CL
template <typename Device, typename T>
const CLImage *Executor<Device, T>::GetOutputImage(
    const std::string &var_name) {
  // Look up `var_name` in the program scope and return its backing CLImage.
  // Returns nullptr when the variable is absent, uninitialized, or holds a
  // different type, so callers can probe names safely.
  auto var = program_.scope->FindVar(var_name);
  // FindVar yields nullptr for unknown names; guard before dereferencing.
  if (var == nullptr) {
    return nullptr;
  }
  if (var->IsInitialized() && var->template IsType<framework::CLImage>()) {
    const CLImage *cl_image = var->template Get<framework::CLImage>();
    return cl_image;
  }
  return nullptr;
}
#endif
template <typename Device, typename T> template <typename Device, typename T>
PMStatus Executor<Device, T>::Predict() { PMStatus Executor<Device, T>::Predict() {
#if _OPENMP #if _OPENMP
......
...@@ -53,6 +53,9 @@ class Executor { ...@@ -53,6 +53,9 @@ class Executor {
void SetInput(const LoDTensor &input, const std::string &var_name); void SetInput(const LoDTensor &input, const std::string &var_name);
std::shared_ptr<LoDTensor> GetOutput(const std::string &var_name); std::shared_ptr<LoDTensor> GetOutput(const std::string &var_name);
#ifdef PADDLE_MOBILE_CL
const CLImage *GetOutputImage(const std::string &var_name);
#endif
void FeedTensorData(const std::vector<framework::Tensor> &v); void FeedTensorData(const std::vector<framework::Tensor> &v);
void GetTensorResults(std::vector<framework::Tensor *> *v); void GetTensorResults(std::vector<framework::Tensor *> *v);
......
...@@ -170,6 +170,14 @@ LoDTensorPtr PaddleMobile<Device, T>::Fetch(const std::string &var_name) { ...@@ -170,6 +170,14 @@ LoDTensorPtr PaddleMobile<Device, T>::Fetch(const std::string &var_name) {
return executor_->GetOutput(var_name); return executor_->GetOutput(var_name);
} }
#ifdef PADDLE_MOBILE_CL
template <typename Device, typename T>
const framework::CLImage *PaddleMobile<Device, T>::FetchImage(
    const std::string &var_name) {
  // Thin facade over the executor: expose the OpenCL image that backs
  // `var_name`. The executor owns the scope and performs the lookup.
  const framework::CLImage *result = executor_->GetOutputImage(var_name);
  return result;
}
#endif
template <typename Device, typename T> template <typename Device, typename T>
void PaddleMobile<Device, T>::Clear() { void PaddleMobile<Device, T>::Clear() {
executor_ = nullptr; executor_ = nullptr;
......
...@@ -74,6 +74,9 @@ class PaddleMobile { ...@@ -74,6 +74,9 @@ class PaddleMobile {
typedef std::shared_ptr<framework::LoDTensor> LoDTensorPtr; typedef std::shared_ptr<framework::LoDTensor> LoDTensorPtr;
LoDTensorPtr Fetch(const std::string &var_name); LoDTensorPtr Fetch(const std::string &var_name);
#ifdef PADDLE_MOBILE_CL
const framework::CLImage *FetchImage(const std::string &var_name);
#endif
LoDTensorPtr Fetch() { return Fetch("fetch"); } LoDTensorPtr Fetch() { return Fetch("fetch"); }
......
...@@ -167,6 +167,44 @@ void test(int argc, char *argv[]) { ...@@ -167,6 +167,44 @@ void test(int argc, char *argv[]) {
paddle_mobile.Feed(var_names[0], input_tensor); paddle_mobile.Feed(var_names[0], input_tensor);
paddle_mobile.Predict(); paddle_mobile.Predict();
} }
#ifdef PADDLE_MOBILE_CL
  // Read every output variable back from its OpenCL image and print a
  // sampled dump of the values for the auto-test harness.
  for (auto var_name : var_names) {
    auto cl_image = paddle_mobile.FetchImage(var_name);
    // FetchImage returns nullptr when the variable is missing or is not a
    // CLImage; skip it instead of dereferencing a null pointer.
    if (cl_image == nullptr) {
      continue;
    }
    auto len = cl_image->numel();
    if (len == 0) {
      continue;
    }
    int width = cl_image->ImageDims()[0];
    int height = cl_image->ImageDims()[1];
    // Half-precision staging buffer: the image layout packs 4 channels
    // per texel, hence the * 4.
    paddle_mobile::framework::half_t *image_data =
        new paddle_mobile::framework::half_t[height * width * 4];
    cl_int err;
    cl_mem image = cl_image->GetCLImage();
    size_t origin[3] = {0, 0, 0};
    size_t region[3] = {static_cast<size_t>(width),
                        static_cast<size_t>(height), 1};
    // Blocking read (CL_TRUE) so image_data is fully populated on return.
    err = clEnqueueReadImage(cl_image->CommandQueue(), image, CL_TRUE, origin,
                             region, 0, 0, image_data, 0, NULL, NULL);
    CL_CHECK_ERRORS(err);
    float *tensor_data = new float[cl_image->numel()];
    // Convert from the image (half, CL layout) to NCHW floats.
    auto converter = cl_image->Converter();
    converter->ImageToNCHW(image_data, tensor_data, cl_image->ImageDims(),
                           cl_image->dims());
    auto data = tensor_data;
    std::string sample = "";
    // Sample evenly across the tensor unless a fixed step was requested.
    if (!is_sample_step) {
      sample_step = len / sample_num;
    }
    if (sample_step <= 0) {
      sample_step = 1;
    }
    for (int i = 0; i < len; i += sample_step) {
      sample += " " + std::to_string(data[i]);
    }
    std::cout << "auto-test"
              << " var " << var_name << sample << std::endl;
    // Release the per-variable staging buffers (previously leaked each
    // iteration).
    delete[] image_data;
    delete[] tensor_data;
  }
#else
for (auto var_name : var_names) { for (auto var_name : var_names) {
auto out = paddle_mobile.Fetch(var_name); auto out = paddle_mobile.Fetch(var_name);
auto len = out->numel(); auto len = out->numel();
...@@ -206,6 +244,7 @@ void test(int argc, char *argv[]) { ...@@ -206,6 +244,7 @@ void test(int argc, char *argv[]) {
<< " var " << var_name << sample << std::endl; << " var " << var_name << sample << std::endl;
} }
} }
#endif
std::cout << std::endl; std::cout << std::endl;
} }
} }
...@@ -33,3 +33,7 @@ limitations under the License. */ ...@@ -33,3 +33,7 @@ limitations under the License. */
#include "framework/tensor.h" #include "framework/tensor.h"
#include "framework/variable.h" #include "framework/variable.h"
#include "io/paddle_mobile.h" #include "io/paddle_mobile.h"
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_image.h"
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册