Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b22f6d69
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b22f6d69
编写于
4月 29, 2021
作者:
L
LielinJiang
提交者:
GitHub
4月 29, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add op read_file and decode_jpeg (#32564)
* add op read_file and decode_jpeg
上级
7a73692b
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
607 addition
and
2 deletion
+607
-2
cmake/operators.cmake
cmake/operators.cmake
+1
-0
paddle/fluid/operators/decode_jpeg_op.cc
paddle/fluid/operators/decode_jpeg_op.cc
+114
-0
paddle/fluid/operators/decode_jpeg_op.cu
paddle/fluid/operators/decode_jpeg_op.cu
+138
-0
paddle/fluid/operators/read_file_op.cc
paddle/fluid/operators/read_file_op.cc
+92
-0
paddle/fluid/platform/dynload/CMakeLists.txt
paddle/fluid/platform/dynload/CMakeLists.txt
+1
-1
paddle/fluid/platform/dynload/dynamic_loader.cc
paddle/fluid/platform/dynload/dynamic_loader.cc
+17
-0
paddle/fluid/platform/dynload/dynamic_loader.h
paddle/fluid/platform/dynload/dynamic_loader.h
+1
-0
paddle/fluid/platform/dynload/nvjpeg.cc
paddle/fluid/platform/dynload/nvjpeg.cc
+27
-0
paddle/fluid/platform/dynload/nvjpeg.h
paddle/fluid/platform/dynload/nvjpeg.h
+53
-0
python/paddle/tests/test_read_file.py
python/paddle/tests/test_read_file.py
+67
-0
python/paddle/vision/ops.py
python/paddle/vision/ops.py
+96
-1
未找到文件。
cmake/operators.cmake
浏览文件 @
b22f6d69
...
...
@@ -182,6 +182,7 @@ function(op_library TARGET)
list
(
REMOVE_ITEM hip_srcs
"cholesky_op.cu"
)
list
(
REMOVE_ITEM hip_srcs
"correlation_op.cu"
)
list
(
REMOVE_ITEM hip_srcs
"multinomial_op.cu"
)
list
(
REMOVE_ITEM hip_srcs
"decode_jpeg_op.cu"
)
hip_library
(
${
TARGET
}
SRCS
${
cc_srcs
}
${
hip_cc_srcs
}
${
miopen_cu_cc_srcs
}
${
miopen_cu_srcs
}
${
mkldnn_cc_srcs
}
${
hip_srcs
}
DEPS
${
op_library_DEPS
}
${
op_common_deps
}
)
else
()
...
...
paddle/fluid/operators/decode_jpeg_op.cc
0 → 100644
浏览文件 @
b22f6d69
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/dynload/nvjpeg.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
CPUDecodeJpegKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
// TODO(LieLinJiang): add cpu implement.
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"DecodeJpeg op only supports GPU now."
));
}
};
class
DecodeJpegOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"DecodeJpeg"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"DecodeJpeg"
);
auto
mode
=
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"mode"
);
std
::
vector
<
int
>
out_dims
;
if
(
mode
==
"unchanged"
)
{
out_dims
=
{
-
1
,
-
1
,
-
1
};
}
else
if
(
mode
==
"gray"
)
{
out_dims
=
{
1
,
-
1
,
-
1
};
}
else
if
(
mode
==
"rgb"
)
{
out_dims
=
{
3
,
-
1
,
-
1
};
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The provided mode is not supported for JPEG files on GPU: "
,
mode
));
}
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
out_dims
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
framework
::
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
{
if
(
var_name
==
"X"
)
{
return
expected_kernel_type
;
}
return
framework
::
OpKernelType
(
tensor
.
type
(),
tensor
.
place
(),
tensor
.
layout
());
}
};
class
DecodeJpegOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"A one dimensional uint8 tensor containing the raw bytes "
"of the JPEG image. It is a tensor with rank 1."
);
AddOutput
(
"Out"
,
"The output tensor of DecodeJpeg op"
);
AddComment
(
R"DOC(
This operator decodes a JPEG image into a 3 dimensional RGB Tensor
or 1 dimensional Gray Tensor. Optionally converts the image to the
desired format. The values of the output tensor are uint8 between 0
and 255.
)DOC"
);
AddAttr
<
std
::
string
>
(
"mode"
,
"(string, default
\"
unchanged
\"
), The read mode used "
"for optionally converting the image, can be
\"
unchanged
\"
"
",
\"
gray
\"
,
\"
rgb
\"
."
)
.
SetDefault
(
"unchanged"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
decode_jpeg
,
ops
::
DecodeJpegOp
,
ops
::
DecodeJpegOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
)
REGISTER_OP_CPU_KERNEL
(
decode_jpeg
,
ops
::
CPUDecodeJpegKernel
<
uint8_t
>
)
paddle/fluid/operators/decode_jpeg_op.cu
0 → 100644
浏览文件 @
b22f6d69
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/dynload/nvjpeg.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
namespace
paddle
{
namespace
operators
{
static
cudaStream_t
nvjpeg_stream
=
nullptr
;
static
nvjpegHandle_t
nvjpeg_handle
=
nullptr
;
void
InitNvjpegImage
(
nvjpegImage_t
*
img
)
{
for
(
int
c
=
0
;
c
<
NVJPEG_MAX_COMPONENT
;
c
++
)
{
img
->
channel
[
c
]
=
nullptr
;
img
->
pitch
[
c
]
=
0
;
}
}
template
<
typename
T
>
class
GPUDecodeJpegKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
// Create nvJPEG handle
if
(
nvjpeg_handle
==
nullptr
)
{
nvjpegStatus_t
create_status
=
platform
::
dynload
::
nvjpegCreateSimple
(
&
nvjpeg_handle
);
PADDLE_ENFORCE_EQ
(
create_status
,
NVJPEG_STATUS_SUCCESS
,
platform
::
errors
::
Fatal
(
"nvjpegCreateSimple failed: "
,
create_status
));
}
nvjpegJpegState_t
nvjpeg_state
;
nvjpegStatus_t
state_status
=
platform
::
dynload
::
nvjpegJpegStateCreate
(
nvjpeg_handle
,
&
nvjpeg_state
);
PADDLE_ENFORCE_EQ
(
state_status
,
NVJPEG_STATUS_SUCCESS
,
platform
::
errors
::
Fatal
(
"nvjpegJpegStateCreate failed: "
,
state_status
));
int
components
;
nvjpegChromaSubsampling_t
subsampling
;
int
widths
[
NVJPEG_MAX_COMPONENT
];
int
heights
[
NVJPEG_MAX_COMPONENT
];
auto
*
x
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
x_data
=
x
->
data
<
T
>
();
nvjpegStatus_t
info_status
=
platform
::
dynload
::
nvjpegGetImageInfo
(
nvjpeg_handle
,
x_data
,
(
size_t
)
x
->
numel
(),
&
components
,
&
subsampling
,
widths
,
heights
);
PADDLE_ENFORCE_EQ
(
info_status
,
NVJPEG_STATUS_SUCCESS
,
platform
::
errors
::
Fatal
(
"nvjpegGetImageInfo failed: "
,
info_status
));
int
width
=
widths
[
0
];
int
height
=
heights
[
0
];
nvjpegOutputFormat_t
output_format
;
int
output_components
;
auto
mode
=
ctx
.
Attr
<
std
::
string
>
(
"mode"
);
if
(
mode
==
"unchanged"
)
{
if
(
components
==
1
)
{
output_format
=
NVJPEG_OUTPUT_Y
;
output_components
=
1
;
}
else
if
(
components
==
3
)
{
output_format
=
NVJPEG_OUTPUT_RGB
;
output_components
=
3
;
}
else
{
platform
::
dynload
::
nvjpegJpegStateDestroy
(
nvjpeg_state
);
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The provided mode is not supported for JPEG files on GPU"
));
}
}
else
if
(
mode
==
"gray"
)
{
output_format
=
NVJPEG_OUTPUT_Y
;
output_components
=
1
;
}
else
if
(
mode
==
"rgb"
)
{
output_format
=
NVJPEG_OUTPUT_RGB
;
output_components
=
3
;
}
else
{
platform
::
dynload
::
nvjpegJpegStateDestroy
(
nvjpeg_state
);
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The provided mode is not supported for JPEG files on GPU"
));
}
nvjpegImage_t
out_image
;
InitNvjpegImage
(
&
out_image
);
// create nvjpeg stream
if
(
nvjpeg_stream
==
nullptr
)
{
cudaStreamCreateWithFlags
(
&
nvjpeg_stream
,
cudaStreamNonBlocking
);
}
int
sz
=
widths
[
0
]
*
heights
[
0
];
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
std
::
vector
<
int64_t
>
out_shape
=
{
output_components
,
height
,
width
};
out
->
Resize
(
framework
::
make_ddim
(
out_shape
));
T
*
data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
for
(
int
c
=
0
;
c
<
output_components
;
c
++
)
{
out_image
.
channel
[
c
]
=
data
+
c
*
sz
;
out_image
.
pitch
[
c
]
=
width
;
}
nvjpegStatus_t
decode_status
=
platform
::
dynload
::
nvjpegDecode
(
nvjpeg_handle
,
nvjpeg_state
,
x_data
,
x
->
numel
(),
output_format
,
&
out_image
,
nvjpeg_stream
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
decode_jpeg
,
ops
::
GPUDecodeJpegKernel
<
uint8_t
>
)
#endif
paddle/fluid/operators/read_file_op.cc
0 → 100644
浏览文件 @
b22f6d69
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
CPUReadFileKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
filename
=
ctx
.
Attr
<
std
::
string
>
(
"filename"
);
std
::
ifstream
input
(
filename
.
c_str
(),
std
::
ios
::
in
|
std
::
ios
::
binary
|
std
::
ios
::
ate
);
std
::
streamsize
file_size
=
input
.
tellg
();
input
.
seekg
(
0
,
std
::
ios
::
beg
);
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
std
::
vector
<
int64_t
>
out_shape
=
{
file_size
};
out
->
Resize
(
framework
::
make_ddim
(
out_shape
));
uint8_t
*
data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
input
.
read
(
reinterpret_cast
<
char
*>
(
data
),
file_size
);
}
};
class
ReadFileOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(Out) of ReadFileOp is null."
));
auto
out_dims
=
std
::
vector
<
int
>
(
1
,
-
1
);
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
out_dims
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
proto
::
VarType
::
UINT8
,
platform
::
CPUPlace
());
}
};
class
ReadFileOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddOutput
(
"Out"
,
"The output tensor of ReadFile op"
);
AddComment
(
R"DOC(
This operator read a file.
)DOC"
);
AddAttr
<
std
::
string
>
(
"filename"
,
"Path of the file to be readed."
)
.
SetDefault
({});
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
read_file
,
ops
::
ReadFileOp
,
ops
::
ReadFileOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
)
REGISTER_OP_CPU_KERNEL
(
read_file
,
ops
::
CPUReadFileKernel
<
uint8_t
>
)
paddle/fluid/platform/dynload/CMakeLists.txt
浏览文件 @
b22f6d69
cc_library
(
dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce
)
list
(
APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc nvtx.cc
)
list
(
APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc nvtx.cc
nvjpeg.cc
)
if
(
WITH_ROCM
)
list
(
APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc
)
...
...
paddle/fluid/platform/dynload/dynamic_loader.cc
浏览文件 @
b22f6d69
...
...
@@ -100,6 +100,9 @@ static constexpr char* win_cublas_lib =
static
constexpr
char
*
win_curand_lib
=
"curand64_"
CUDA_VERSION_MAJOR
CUDA_VERSION_MINOR
".dll;curand64_"
CUDA_VERSION_MAJOR
".dll;curand64_10.dll"
;
static
constexpr
char
*
win_nvjpeg_lib
=
"nvjpeg64_"
CUDA_VERSION_MAJOR
CUDA_VERSION_MINOR
".dll;nvjpeg64_"
CUDA_VERSION_MAJOR
".dll;nvjpeg64_10.dll"
;
static
constexpr
char
*
win_cusolver_lib
=
"cusolver64_"
CUDA_VERSION_MAJOR
CUDA_VERSION_MINOR
".dll;cusolver64_"
CUDA_VERSION_MAJOR
".dll;cusolver64_10.dll"
;
...
...
@@ -107,6 +110,9 @@ static constexpr char* win_cusolver_lib =
static
constexpr
char
*
win_curand_lib
=
"curand64_"
CUDA_VERSION_MAJOR
CUDA_VERSION_MINOR
".dll;curand64_"
CUDA_VERSION_MAJOR
".dll"
;
static
constexpr
char
*
win_nvjpeg_lib
=
"nvjpeg64_"
CUDA_VERSION_MAJOR
CUDA_VERSION_MINOR
".dll;nvjpeg64_"
CUDA_VERSION_MAJOR
".dll"
;
static
constexpr
char
*
win_cusolver_lib
=
"cusolver64_"
CUDA_VERSION_MAJOR
CUDA_VERSION_MINOR
".dll;cusolver64_"
CUDA_VERSION_MAJOR
".dll"
;
...
...
@@ -330,6 +336,17 @@ void* GetCurandDsoHandle() {
#endif
}
void
*
GetNvjpegDsoHandle
()
{
#if defined(__APPLE__) || defined(__OSX__)
return
GetDsoHandleFromSearchPath
(
FLAGS_cuda_dir
,
"libnvjpeg.dylib"
);
#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
return
GetDsoHandleFromSearchPath
(
FLAGS_cuda_dir
,
win_nvjpeg_lib
,
true
,
{
cuda_lib_path
});
#else
return
GetDsoHandleFromSearchPath
(
FLAGS_cuda_dir
,
"libnvjpeg.so"
);
#endif
}
void
*
GetCusolverDsoHandle
()
{
#if defined(__APPLE__) || defined(__OSX__)
return
GetDsoHandleFromSearchPath
(
FLAGS_cuda_dir
,
"libcusolver.dylib"
);
...
...
paddle/fluid/platform/dynload/dynamic_loader.h
浏览文件 @
b22f6d69
...
...
@@ -29,6 +29,7 @@ void* GetCublasDsoHandle();
void
*
GetCUDNNDsoHandle
();
void
*
GetCUPTIDsoHandle
();
void
*
GetCurandDsoHandle
();
void
*
GetNvjpegDsoHandle
();
void
*
GetCusolverDsoHandle
();
void
*
GetNVRTCDsoHandle
();
void
*
GetCUDADsoHandle
();
...
...
paddle/fluid/platform/dynload/nvjpeg.cc
0 → 100644
浏览文件 @
b22f6d69
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/nvjpeg.h"
namespace
paddle
{
namespace
platform
{
namespace
dynload
{
std
::
once_flag
nvjpeg_dso_flag
;
void
*
nvjpeg_dso_handle
;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
NVJPEG_RAND_ROUTINE_EACH
(
DEFINE_WRAP
);
}
// namespace dynload
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/dynload/nvjpeg.h
0 → 100644
浏览文件 @
b22f6d69
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <nvjpeg.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace
paddle
{
namespace
platform
{
namespace
dynload
{
extern
std
::
once_flag
nvjpeg_dso_flag
;
extern
void
*
nvjpeg_dso_handle
;
#define DECLARE_DYNAMIC_LOAD_NVJPEG_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
nvjpegStatus_t operator()(Args... args) { \
using nvjpegFunc = decltype(&::__name); \
std::call_once(nvjpeg_dso_flag, []() { \
nvjpeg_dso_handle = paddle::platform::dynload::GetNvjpegDsoHandle(); \
}); \
static void *p_##__name = dlsym(nvjpeg_dso_handle, #__name); \
return reinterpret_cast<nvjpegFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#define NVJPEG_RAND_ROUTINE_EACH(__macro) \
__macro(nvjpegCreateSimple); \
__macro(nvjpegJpegStateCreate); \
__macro(nvjpegGetImageInfo); \
__macro(nvjpegJpegStateDestroy); \
__macro(nvjpegDecode);
NVJPEG_RAND_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_NVJPEG_WRAP
);
}
// namespace dynload
}
// namespace platform
}
// namespace paddle
#endif
python/paddle/tests/test_read_file.py
0 → 100644
浏览文件 @
b22f6d69
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
shutil
import
unittest
import
numpy
as
np
import
paddle
from
paddle.vision.ops
import
read_file
,
decode_jpeg
class
TestReadFile
(
unittest
.
TestCase
):
def
setUp
(
self
):
fake_img
=
(
np
.
random
.
random
((
400
,
300
,
3
))
*
255
).
astype
(
'uint8'
)
cv2
.
imwrite
(
'fake.jpg'
,
fake_img
)
def
tearDown
(
self
):
os
.
remove
(
'fake.jpg'
)
def
read_file_decode_jpeg
(
self
):
if
not
paddle
.
is_compiled_with_cuda
():
return
img_bytes
=
read_file
(
'fake.jpg'
)
img
=
decode_jpeg
(
img_bytes
,
mode
=
'gray'
)
img
=
decode_jpeg
(
img_bytes
,
mode
=
'rgb'
)
img
=
decode_jpeg
(
img_bytes
)
img_cv2
=
cv2
.
imread
(
'fake.jpg'
)
if
paddle
.
in_dynamic_mode
():
np
.
testing
.
assert_equal
(
img
.
shape
,
img_cv2
.
transpose
(
2
,
0
,
1
).
shape
)
else
:
place
=
paddle
.
CUDAPlace
(
0
)
exe
=
paddle
.
static
.
Executor
(
place
)
exe
.
run
(
paddle
.
static
.
default_startup_program
())
out
=
exe
.
run
(
paddle
.
static
.
default_main_program
(),
fetch_list
=
[
img
])
np
.
testing
.
assert_equal
(
out
[
0
].
shape
,
img_cv2
.
transpose
(
2
,
0
,
1
).
shape
)
def
test_read_file_decode_jpeg_dynamic
(
self
):
self
.
read_file_decode_jpeg
()
def
test_read_file_decode_jpeg_static
(
self
):
paddle
.
enable_static
()
self
.
read_file_decode_jpeg
()
paddle
.
disable_static
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/vision/ops.py
浏览文件 @
b22f6d69
...
...
@@ -22,7 +22,10 @@ from ..fluid.initializer import Normal
from
paddle.common_ops_import
import
*
__all__
=
[
'yolo_loss'
,
'yolo_box'
,
'deform_conv2d'
,
'DeformConv2D'
]
__all__
=
[
'yolo_loss'
,
'yolo_box'
,
'deform_conv2d'
,
'DeformConv2D'
,
'read_file'
,
'decode_jpeg'
]
def
yolo_loss
(
x
,
...
...
@@ -782,3 +785,95 @@ class DeformConv2D(Layer):
groups
=
self
.
_groups
,
mask
=
mask
)
return
out
def
read_file
(
filename
,
name
=
None
):
"""
Reads and outputs the bytes contents of a file as a uint8 Tensor
with one dimension.
Args:
filename (str): Path of the file to be read.
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
A uint8 tensor.
Examples:
.. code-block:: python
import cv2
import paddle
fake_img = (np.random.random(
(400, 300, 3)) * 255).astype('uint8')
cv2.imwrite('fake.jpg', fake_img)
img_bytes = paddle.vision.ops.read_file('fake.jpg')
print(img_bytes.shape)
"""
if
in_dygraph_mode
():
return
core
.
ops
.
read_file
(
'filename'
,
filename
)
inputs
=
dict
()
attrs
=
{
'filename'
:
filename
}
helper
=
LayerHelper
(
"read_file"
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
'uint8'
)
helper
.
append_op
(
type
=
"read_file"
,
inputs
=
inputs
,
attrs
=
attrs
,
outputs
=
{
"Out"
:
out
})
return
out
def
decode_jpeg
(
x
,
mode
=
'unchanged'
,
name
=
None
):
"""
Decodes a JPEG image into a 3 dimensional RGB Tensor or 1 dimensional Gray Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
Args:
x (Tensor): A one dimensional uint8 tensor containing the raw bytes
of the JPEG image.
mode (str): The read mode used for optionally converting the image.
Default: 'unchanged'.
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor: A decoded image tensor with shape (imge_channels, image_height, image_width)
Examples:
.. code-block:: python
import cv2
import paddle
fake_img = (np.random.random(
(400, 300, 3)) * 255).astype('uint8')
cv2.imwrite('fake.jpg', fake_img)
img_bytes = paddle.vision.ops.read_file('fake.jpg')
img = paddle.vision.ops.decode_jpeg(img_bytes)
print(img.shape)
"""
if
in_dygraph_mode
():
return
core
.
ops
.
decode_jpeg
(
x
,
"mode"
,
mode
)
inputs
=
{
'X'
:
x
}
attrs
=
{
"mode"
:
mode
}
helper
=
LayerHelper
(
"decode_jpeg"
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
'uint8'
)
helper
.
append_op
(
type
=
"decode_jpeg"
,
inputs
=
inputs
,
attrs
=
attrs
,
outputs
=
{
"Out"
:
out
})
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录