提交 f7b97680 编写于 作者: Z zp7 提交者: Yanzhan Yang

fix reshape2&transpose2 gpu infershape (#1778)

上级 51f8d740
...@@ -122,6 +122,8 @@ Print &operator<<(Print &printer, const CLImage &cl_image) { ...@@ -122,6 +122,8 @@ Print &operator<<(Print &printer, const CLImage &cl_image) {
CL_CHECK_ERRORS(err); CL_CHECK_ERRORS(err);
PADDLE_MOBILE_ENFORCE(cl_image.numel() != 0,
"cl_image numel should not be 0 ");
float *tensor_data = new float[cl_image.numel()]; float *tensor_data = new float[cl_image.numel()];
auto converter = cl_image.Converter(); auto converter = cl_image.Converter();
converter->ImageToNCHW(image_data, tensor_data, cl_image.ImageDims(), converter->ImageToNCHW(image_data, tensor_data, cl_image.ImageDims(),
......
...@@ -391,6 +391,8 @@ void CLImageConverterDWBlock::ImageToNCHW(half_t *image, float *tensor, ...@@ -391,6 +391,8 @@ void CLImageConverterDWBlock::ImageToNCHW(half_t *image, float *tensor,
const DDim &CLImageConverterNormal::InitImageDimInfoWith( const DDim &CLImageConverterNormal::InitImageDimInfoWith(
const DDim &tensor_dim) { const DDim &tensor_dim) {
PADDLE_MOBILE_ENFORCE(tensor_dim.size() <= 4 && tensor_dim.size() > 0,
"tensor dim is not support ");
size_t new_dims[] = {1, 1, 1, 1}; size_t new_dims[] = {1, 1, 1, 1};
for (int j = 0; j < tensor_dim.size(); ++j) { for (int j = 0; j < tensor_dim.size(); ++j) {
new_dims[4 - tensor_dim.size() + j] = tensor_dim[j]; new_dims[4 - tensor_dim.size() + j] = tensor_dim[j];
......
...@@ -1027,7 +1027,7 @@ void Executor<GPU_CL, float>::InitCombineMemory() { ...@@ -1027,7 +1027,7 @@ void Executor<GPU_CL, float>::InitCombineMemory() {
bool shouldResize = true; bool shouldResize = true;
if (ddim.size() > 4) { if (ddim.size() > 4) {
for (int i = 0; i < ddim.size() - 4; ++i) { for (int i = 0; i < ddim.size() - 4; ++i) {
if (ddim[i] != 0) { if (ddim[i] != 0 && ddim[i] != 1) {
shouldResize = false; shouldResize = false;
break; break;
} }
......
...@@ -75,6 +75,9 @@ void Reshape2Op<Dtype, T>::InferShape() const { ...@@ -75,6 +75,9 @@ void Reshape2Op<Dtype, T>::InferShape() const {
xshape_dims[i + 1] = input_x_dims[i]; xshape_dims[i + 1] = input_x_dims[i];
} }
this->param_.OutputXShape()->Resize(framework::make_ddim(xshape_dims)); this->param_.OutputXShape()->Resize(framework::make_ddim(xshape_dims));
#ifdef PADDLE_MOBILE_CL
this->param_.OutputXShape()->Resize(input_x_dims);
#endif
} }
} // namespace operators } // namespace operators
......
...@@ -100,6 +100,9 @@ void Transpose2Op<Dtype, T>::InferShape() const { ...@@ -100,6 +100,9 @@ void Transpose2Op<Dtype, T>::InferShape() const {
xshape_dims[i + 1] = input_x_dims[i]; xshape_dims[i + 1] = input_x_dims[i];
} }
this->param_.OutputXShape()->Resize(framework::make_ddim(xshape_dims)); this->param_.OutputXShape()->Resize(framework::make_ddim(xshape_dims));
#ifdef PADDLE_MOBILE_CL
this->param_.OutputXShape()->Resize(input_x_dims);
#endif
} }
} // namespace operators } // namespace operators
......
...@@ -155,7 +155,7 @@ def load_feed_kv(): ...@@ -155,7 +155,7 @@ def load_feed_kv():
expected_len = 1 expected_len = 1
for dim in feed_shape: for dim in feed_shape:
expected_len *= dim expected_len *= dim
if len(data) != expected_len: if len(np.atleast_1d(data)) != expected_len:
return None return None
data = data.reshape(feed_shape).astype("float32") data = data.reshape(feed_shape).astype("float32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册