未验证 提交 b22f6d69 编写于 作者: L LielinJiang 提交者: GitHub

Add op read_file and decode_jpeg (#32564)

* add op read_file and decode_jpeg
上级 7a73692b
......@@ -182,6 +182,7 @@ function(op_library TARGET)
list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
list(REMOVE_ITEM hip_srcs "correlation_op.cu")
list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
else()
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/dynload/nvjpeg.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
template <typename T>
class CPUDecodeJpegKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// TODO(LieLinJiang): add cpu implement.
PADDLE_THROW(platform::errors::Unimplemented(
"DecodeJpeg op only supports GPU now."));
}
};
class DecodeJpegOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DecodeJpeg");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "DecodeJpeg");
auto mode = ctx->Attrs().Get<std::string>("mode");
std::vector<int> out_dims;
if (mode == "unchanged") {
out_dims = {-1, -1, -1};
} else if (mode == "gray") {
out_dims = {1, -1, -1};
} else if (mode == "rgb") {
out_dims = {3, -1, -1};
} else {
PADDLE_THROW(platform::errors::Fatal(
"The provided mode is not supported for JPEG files on GPU: ", mode));
}
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
if (var_name == "X") {
return expected_kernel_type;
}
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
}
};
class DecodeJpegOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"A one dimensional uint8 tensor containing the raw bytes "
"of the JPEG image. It is a tensor with rank 1.");
AddOutput("Out", "The output tensor of DecodeJpeg op");
AddComment(R"DOC(
This operator decodes a JPEG image into a 3 dimensional RGB Tensor
or 1 dimensional Gray Tensor. Optionally converts the image to the
desired format. The values of the output tensor are uint8 between 0
and 255.
)DOC");
AddAttr<std::string>(
"mode",
"(string, default \"unchanged\"), The read mode used "
"for optionally converting the image, can be \"unchanged\" "
",\"gray\" , \"rgb\" .")
.SetDefault("unchanged");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
decode_jpeg, ops::DecodeJpegOp, ops::DecodeJpegOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
REGISTER_OP_CPU_KERNEL(decode_jpeg, ops::CPUDecodeJpegKernel<uint8_t>)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/dynload/nvjpeg.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
namespace paddle {
namespace operators {
static cudaStream_t nvjpeg_stream = nullptr;
static nvjpegHandle_t nvjpeg_handle = nullptr;
void InitNvjpegImage(nvjpegImage_t* img) {
for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) {
img->channel[c] = nullptr;
img->pitch[c] = 0;
}
}
template <typename T>
class GPUDecodeJpegKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Create nvJPEG handle
if (nvjpeg_handle == nullptr) {
nvjpegStatus_t create_status =
platform::dynload::nvjpegCreateSimple(&nvjpeg_handle);
PADDLE_ENFORCE_EQ(create_status, NVJPEG_STATUS_SUCCESS,
platform::errors::Fatal("nvjpegCreateSimple failed: ",
create_status));
}
nvjpegJpegState_t nvjpeg_state;
nvjpegStatus_t state_status =
platform::dynload::nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state);
PADDLE_ENFORCE_EQ(state_status, NVJPEG_STATUS_SUCCESS,
platform::errors::Fatal("nvjpegJpegStateCreate failed: ",
state_status));
int components;
nvjpegChromaSubsampling_t subsampling;
int widths[NVJPEG_MAX_COMPONENT];
int heights[NVJPEG_MAX_COMPONENT];
auto* x = ctx.Input<framework::Tensor>("X");
auto* x_data = x->data<T>();
nvjpegStatus_t info_status = platform::dynload::nvjpegGetImageInfo(
nvjpeg_handle, x_data, (size_t)x->numel(), &components, &subsampling,
widths, heights);
PADDLE_ENFORCE_EQ(
info_status, NVJPEG_STATUS_SUCCESS,
platform::errors::Fatal("nvjpegGetImageInfo failed: ", info_status));
int width = widths[0];
int height = heights[0];
nvjpegOutputFormat_t output_format;
int output_components;
auto mode = ctx.Attr<std::string>("mode");
if (mode == "unchanged") {
if (components == 1) {
output_format = NVJPEG_OUTPUT_Y;
output_components = 1;
} else if (components == 3) {
output_format = NVJPEG_OUTPUT_RGB;
output_components = 3;
} else {
platform::dynload::nvjpegJpegStateDestroy(nvjpeg_state);
PADDLE_THROW(platform::errors::Fatal(
"The provided mode is not supported for JPEG files on GPU"));
}
} else if (mode == "gray") {
output_format = NVJPEG_OUTPUT_Y;
output_components = 1;
} else if (mode == "rgb") {
output_format = NVJPEG_OUTPUT_RGB;
output_components = 3;
} else {
platform::dynload::nvjpegJpegStateDestroy(nvjpeg_state);
PADDLE_THROW(platform::errors::Fatal(
"The provided mode is not supported for JPEG files on GPU"));
}
nvjpegImage_t out_image;
InitNvjpegImage(&out_image);
// create nvjpeg stream
if (nvjpeg_stream == nullptr) {
cudaStreamCreateWithFlags(&nvjpeg_stream, cudaStreamNonBlocking);
}
int sz = widths[0] * heights[0];
auto* out = ctx.Output<framework::LoDTensor>("Out");
std::vector<int64_t> out_shape = {output_components, height, width};
out->Resize(framework::make_ddim(out_shape));
T* data = out->mutable_data<T>(ctx.GetPlace());
for (int c = 0; c < output_components; c++) {
out_image.channel[c] = data + c * sz;
out_image.pitch[c] = width;
}
nvjpegStatus_t decode_status = platform::dynload::nvjpegDecode(
nvjpeg_handle, nvjpeg_state, x_data, x->numel(), output_format,
&out_image, nvjpeg_stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(decode_jpeg, ops::GPUDecodeJpegKernel<uint8_t>)
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
template <typename T>
class CPUReadFileKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto filename = ctx.Attr<std::string>("filename");
std::ifstream input(filename.c_str(),
std::ios::in | std::ios::binary | std::ios::ate);
std::streamsize file_size = input.tellg();
input.seekg(0, std::ios::beg);
auto* out = ctx.Output<framework::LoDTensor>("Out");
std::vector<int64_t> out_shape = {file_size};
out->Resize(framework::make_ddim(out_shape));
uint8_t* data = out->mutable_data<T>(ctx.GetPlace());
input.read(reinterpret_cast<char*>(data), file_size);
}
};
class ReadFileOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of ReadFileOp is null."));
auto out_dims = std::vector<int>(1, -1);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::UINT8,
platform::CPUPlace());
}
};
class ReadFileOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Out", "The output tensor of ReadFile op");
AddComment(R"DOC(
This operator read a file.
)DOC");
AddAttr<std::string>("filename", "Path of the file to be readed.")
.SetDefault({});
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
read_file, ops::ReadFileOp, ops::ReadFileOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
REGISTER_OP_CPU_KERNEL(read_file, ops::CPUReadFileKernel<uint8_t>)
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc nvtx.cc)
list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc nvtx.cc nvjpeg.cc)
if (WITH_ROCM)
list(APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc)
......
......@@ -100,6 +100,9 @@ static constexpr char* win_cublas_lib =
static constexpr char* win_curand_lib =
"curand64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
".dll;curand64_" CUDA_VERSION_MAJOR ".dll;curand64_10.dll";
static constexpr char* win_nvjpeg_lib =
"nvjpeg64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
".dll;nvjpeg64_" CUDA_VERSION_MAJOR ".dll;nvjpeg64_10.dll";
static constexpr char* win_cusolver_lib =
"cusolver64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
".dll;cusolver64_" CUDA_VERSION_MAJOR ".dll;cusolver64_10.dll";
......@@ -107,6 +110,9 @@ static constexpr char* win_cusolver_lib =
static constexpr char* win_curand_lib =
"curand64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
".dll;curand64_" CUDA_VERSION_MAJOR ".dll";
static constexpr char* win_nvjpeg_lib =
"nvjpeg64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
".dll;nvjpeg64_" CUDA_VERSION_MAJOR ".dll";
static constexpr char* win_cusolver_lib =
"cusolver64_" CUDA_VERSION_MAJOR CUDA_VERSION_MINOR
".dll;cusolver64_" CUDA_VERSION_MAJOR ".dll";
......@@ -330,6 +336,17 @@ void* GetCurandDsoHandle() {
#endif
}
void* GetNvjpegDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvjpeg.dylib");
#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, win_nvjpeg_lib, true,
{cuda_lib_path});
#else
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvjpeg.so");
#endif
}
void* GetCusolverDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusolver.dylib");
......
......@@ -29,6 +29,7 @@ void* GetCublasDsoHandle();
void* GetCUDNNDsoHandle();
void* GetCUPTIDsoHandle();
void* GetCurandDsoHandle();
void* GetNvjpegDsoHandle();
void* GetCusolverDsoHandle();
void* GetNVRTCDsoHandle();
void* GetCUDADsoHandle();
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/nvjpeg.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag nvjpeg_dso_flag;
void *nvjpeg_dso_handle;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
NVJPEG_RAND_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <nvjpeg.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag nvjpeg_dso_flag;
extern void *nvjpeg_dso_handle;
#define DECLARE_DYNAMIC_LOAD_NVJPEG_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
nvjpegStatus_t operator()(Args... args) { \
using nvjpegFunc = decltype(&::__name); \
std::call_once(nvjpeg_dso_flag, []() { \
nvjpeg_dso_handle = paddle::platform::dynload::GetNvjpegDsoHandle(); \
}); \
static void *p_##__name = dlsym(nvjpeg_dso_handle, #__name); \
return reinterpret_cast<nvjpegFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#define NVJPEG_RAND_ROUTINE_EACH(__macro) \
__macro(nvjpegCreateSimple); \
__macro(nvjpegJpegStateCreate); \
__macro(nvjpegGetImageInfo); \
__macro(nvjpegJpegStateDestroy); \
__macro(nvjpegDecode);
NVJPEG_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NVJPEG_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
#endif
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import shutil
import unittest
import numpy as np
import paddle
from paddle.vision.ops import read_file, decode_jpeg
class TestReadFile(unittest.TestCase):
def setUp(self):
fake_img = (np.random.random((400, 300, 3)) * 255).astype('uint8')
cv2.imwrite('fake.jpg', fake_img)
def tearDown(self):
os.remove('fake.jpg')
def read_file_decode_jpeg(self):
if not paddle.is_compiled_with_cuda():
return
img_bytes = read_file('fake.jpg')
img = decode_jpeg(img_bytes, mode='gray')
img = decode_jpeg(img_bytes, mode='rgb')
img = decode_jpeg(img_bytes)
img_cv2 = cv2.imread('fake.jpg')
if paddle.in_dynamic_mode():
np.testing.assert_equal(img.shape, img_cv2.transpose(2, 0, 1).shape)
else:
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
out = exe.run(paddle.static.default_main_program(),
fetch_list=[img])
np.testing.assert_equal(out[0].shape,
img_cv2.transpose(2, 0, 1).shape)
def test_read_file_decode_jpeg_dynamic(self):
self.read_file_decode_jpeg()
def test_read_file_decode_jpeg_static(self):
paddle.enable_static()
self.read_file_decode_jpeg()
paddle.disable_static()
if __name__ == '__main__':
unittest.main()
......@@ -22,7 +22,10 @@ from ..fluid.initializer import Normal
from paddle.common_ops_import import *
__all__ = ['yolo_loss', 'yolo_box', 'deform_conv2d', 'DeformConv2D']
__all__ = [
'yolo_loss', 'yolo_box', 'deform_conv2d', 'DeformConv2D', 'read_file',
'decode_jpeg'
]
def yolo_loss(x,
......@@ -782,3 +785,95 @@ class DeformConv2D(Layer):
groups=self._groups,
mask=mask)
return out
def read_file(filename, name=None):
"""
Reads and outputs the bytes contents of a file as a uint8 Tensor
with one dimension.
Args:
filename (str): Path of the file to be read.
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
A uint8 tensor.
Examples:
.. code-block:: python
import cv2
import paddle
fake_img = (np.random.random(
(400, 300, 3)) * 255).astype('uint8')
cv2.imwrite('fake.jpg', fake_img)
img_bytes = paddle.vision.ops.read_file('fake.jpg')
print(img_bytes.shape)
"""
if in_dygraph_mode():
return core.ops.read_file('filename', filename)
inputs = dict()
attrs = {'filename': filename}
helper = LayerHelper("read_file", **locals())
out = helper.create_variable_for_type_inference('uint8')
helper.append_op(
type="read_file", inputs=inputs, attrs=attrs, outputs={"Out": out})
return out
def decode_jpeg(x, mode='unchanged', name=None):
"""
Decodes a JPEG image into a 3 dimensional RGB Tensor or 1 dimensional Gray Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
Args:
x (Tensor): A one dimensional uint8 tensor containing the raw bytes
of the JPEG image.
mode (str): The read mode used for optionally converting the image.
Default: 'unchanged'.
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor: A decoded image tensor with shape (imge_channels, image_height, image_width)
Examples:
.. code-block:: python
import cv2
import paddle
fake_img = (np.random.random(
(400, 300, 3)) * 255).astype('uint8')
cv2.imwrite('fake.jpg', fake_img)
img_bytes = paddle.vision.ops.read_file('fake.jpg')
img = paddle.vision.ops.decode_jpeg(img_bytes)
print(img.shape)
"""
if in_dygraph_mode():
return core.ops.decode_jpeg(x, "mode", mode)
inputs = {'X': x}
attrs = {"mode": mode}
helper = LayerHelper("decode_jpeg", **locals())
out = helper.create_variable_for_type_inference('uint8')
helper.append_op(
type="decode_jpeg", inputs=inputs, attrs=attrs, outputs={"Out": out})
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册