Unverified commit a2265028 authored by wuyefeilin, committed by GitHub

[Phi] mv decode_jpeg (#44645)

* mv kernel

* mv infershape

* mv yaml

* update some

* maintain decode_jpeg in old dygraph

* fix as review

* rm decode_jpeg_op.cu

* update for rocm
Parent a9f3719b
@@ -17,48 +17,20 @@
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
template <typename T>
class CPUDecodeJpegKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// TODO(LieLinJiang): add a CPU implementation.
PADDLE_THROW(platform::errors::Unimplemented(
"DecodeJpeg op only supports GPU now."));
}
};
class DecodeJpegOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DecodeJpeg");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "DecodeJpeg");
auto mode = ctx->Attrs().Get<std::string>("mode");
std::vector<int> out_dims;
if (mode == "unchanged") {
out_dims = {-1, -1, -1};
} else if (mode == "gray") {
out_dims = {1, -1, -1};
} else if (mode == "rgb") {
out_dims = {3, -1, -1};
} else {
PADDLE_THROW(platform::errors::Fatal(
"The provided mode is not supported for JPEG files on GPU: ", mode));
}
ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -107,12 +79,14 @@ and 255.
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(decode_jpeg,
DecodeJpegInferShapeFunctor,
PD_INFER_META(phi::DecodeJpegInferMeta));
REGISTER_OPERATOR(
decode_jpeg,
ops::DecodeJpegOp,
ops::DecodeJpegOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
REGISTER_OP_CPU_KERNEL(decode_jpeg, ops::CPUDecodeJpegKernel<uint8_t>)
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
DecodeJpegInferShapeFunctor)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if !defined(WITH_NV_JETSON) && !defined(PADDLE_WITH_HIP)
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/dynload/nvjpeg.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
static cudaStream_t nvjpeg_stream = nullptr;
static nvjpegHandle_t nvjpeg_handle = nullptr;
void InitNvjpegImage(nvjpegImage_t* img) {
for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) {
img->channel[c] = nullptr;
img->pitch[c] = 0;
}
}
template <typename T>
class GPUDecodeJpegKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Create nvJPEG handle
if (nvjpeg_handle == nullptr) {
nvjpegStatus_t create_status =
platform::dynload::nvjpegCreateSimple(&nvjpeg_handle);
PADDLE_ENFORCE_EQ(create_status,
NVJPEG_STATUS_SUCCESS,
platform::errors::Fatal("nvjpegCreateSimple failed: ",
create_status));
}
nvjpegJpegState_t nvjpeg_state;
nvjpegStatus_t state_status =
platform::dynload::nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state);
PADDLE_ENFORCE_EQ(state_status,
NVJPEG_STATUS_SUCCESS,
platform::errors::Fatal("nvjpegJpegStateCreate failed: ",
state_status));
int components;
nvjpegChromaSubsampling_t subsampling;
int widths[NVJPEG_MAX_COMPONENT];
int heights[NVJPEG_MAX_COMPONENT];
auto* x = ctx.Input<framework::Tensor>("X");
auto* x_data = x->data<T>();
nvjpegStatus_t info_status =
platform::dynload::nvjpegGetImageInfo(nvjpeg_handle,
x_data,
static_cast<size_t>(x->numel()),
&components,
&subsampling,
widths,
heights);
PADDLE_ENFORCE_EQ(
info_status,
NVJPEG_STATUS_SUCCESS,
platform::errors::Fatal("nvjpegGetImageInfo failed: ", info_status));
int width = widths[0];
int height = heights[0];
nvjpegOutputFormat_t output_format;
int output_components;
auto mode = ctx.Attr<std::string>("mode");
if (mode == "unchanged") {
if (components == 1) {
output_format = NVJPEG_OUTPUT_Y;
output_components = 1;
} else if (components == 3) {
output_format = NVJPEG_OUTPUT_RGB;
output_components = 3;
} else {
platform::dynload::nvjpegJpegStateDestroy(nvjpeg_state);
PADDLE_THROW(platform::errors::Fatal(
"The provided mode is not supported for JPEG files on GPU"));
}
} else if (mode == "gray") {
output_format = NVJPEG_OUTPUT_Y;
output_components = 1;
} else if (mode == "rgb") {
output_format = NVJPEG_OUTPUT_RGB;
output_components = 3;
} else {
platform::dynload::nvjpegJpegStateDestroy(nvjpeg_state);
PADDLE_THROW(platform::errors::Fatal(
"The provided mode is not supported for JPEG files on GPU"));
}
nvjpegImage_t out_image;
InitNvjpegImage(&out_image);
// create nvjpeg stream
if (nvjpeg_stream == nullptr) {
cudaStreamCreateWithFlags(&nvjpeg_stream, cudaStreamNonBlocking);
}
int sz = widths[0] * heights[0];
auto* out = ctx.Output<framework::LoDTensor>("Out");
std::vector<int64_t> out_shape = {output_components, height, width};
out->Resize(phi::make_ddim(out_shape));
T* data = out->mutable_data<T>(ctx.GetPlace());
for (int c = 0; c < output_components; c++) {
out_image.channel[c] = data + c * sz;
out_image.pitch[c] = width;
}
nvjpegStatus_t decode_status =
platform::dynload::nvjpegDecode(nvjpeg_handle,
nvjpeg_state,
x_data,
x->numel(),
output_format,
&out_image,
nvjpeg_stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(decode_jpeg, ops::GPUDecodeJpegKernel<uint8_t>)
#endif
@@ -585,6 +585,15 @@
func : cumsum
backward : cumsum_grad
# decode_jpeg
- api : decode_jpeg
args : (Tensor x, str mode)
output : Tensor(out)
infer_meta :
func : DecodeJpegInferMeta
kernel :
func : decode_jpeg
- api : deformable_conv
args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step)
output : Tensor(out)
@@ -341,6 +341,27 @@ void CropTensorInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void DecodeJpegInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* out) {
std::vector<int> out_dims;
if (mode == "unchanged") {
out_dims = {-1, -1, -1};
} else if (mode == "gray") {
out_dims = {1, -1, -1};
} else if (mode == "rgb") {
out_dims = {3, -1, -1};
} else {
PADDLE_THROW(errors::Fatal(
    "The provided mode is not supported for JPEG files on GPU: ", mode));
}
if (out != nullptr) {
out->set_dims(phi::make_ddim(out_dims));
out->set_dtype(x.dtype());
}
}
void DiagEmbedInferMeta(
const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out) {
auto x_dims = x.dims();
@@ -1546,11 +1567,10 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
auto x_dims = x.dims();
PADDLE_ENFORCE(
x_dims.size() == 4 || x_dims.size() == 5,
errors::InvalidArgument(
"Pooling intput should be 4-D or 5-D tensor but received %dD-Tensor",
x_dims.size()));
PADDLE_ENFORCE(x_dims.size() == 4 || x_dims.size() == 5,
errors::InvalidArgument("Pooling intput should be 4-D or "
"5-D tensor but received %dD-Tensor",
x_dims.size()));
if (global_pooling) {
kernel_size_.resize(static_cast<size_t>(x_dims.size()) - 2);
@@ -3032,7 +3052,8 @@ void StridedSliceInferMeta(const MetaTensor& x,
}
/* Why not use SumRawInferMeta directly?
Because we need to make InferMetaFunction's args follow the design of api.yaml
Because we need to make InferMetaFunction's args follow the design of
api.yaml
*/
void SumInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
@@ -79,6 +79,10 @@ void CumInferMeta(const MetaTensor& x,
bool reverse,
MetaTensor* out);
void DecodeJpegInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* out);
void DiagEmbedInferMeta(
const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/decode_jpeg_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void DecodeJpegKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& mode,
DenseTensor* out) {
PADDLE_THROW(errors::Unimplemented("DecodeJpeg op only supports GPU now."));
}
} // namespace phi
PD_REGISTER_KERNEL(
decode_jpeg, CPU, ALL_LAYOUT, phi::DecodeJpegKernel, uint8_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void DecodeJpegKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& mode,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if !defined(WITH_NV_JETSON) && !defined(PADDLE_WITH_HIP)
#include "paddle/phi/kernels/decode_jpeg_kernel.h"
#include "paddle/phi/backends/dynload/nvjpeg.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/stream.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
static cudaStream_t nvjpeg_stream = nullptr;
static nvjpegHandle_t nvjpeg_handle = nullptr;
void InitNvjpegImage(nvjpegImage_t* img) {
for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) {
img->channel[c] = nullptr;
img->pitch[c] = 0;
}
}
template <typename T, typename Context>
void DecodeJpegKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& mode,
DenseTensor* out) {
// Create nvJPEG handle
if (nvjpeg_handle == nullptr) {
nvjpegStatus_t create_status =
phi::dynload::nvjpegCreateSimple(&nvjpeg_handle);
PADDLE_ENFORCE_EQ(
create_status,
NVJPEG_STATUS_SUCCESS,
errors::Fatal("nvjpegCreateSimple failed: ", create_status));
}
nvjpegJpegState_t nvjpeg_state;
nvjpegStatus_t state_status =
phi::dynload::nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state);
PADDLE_ENFORCE_EQ(
state_status,
NVJPEG_STATUS_SUCCESS,
errors::Fatal("nvjpegJpegStateCreate failed: ", state_status));
int components;
nvjpegChromaSubsampling_t subsampling;
int widths[NVJPEG_MAX_COMPONENT];
int heights[NVJPEG_MAX_COMPONENT];
auto* x_data = x.data<T>();
nvjpegStatus_t info_status =
phi::dynload::nvjpegGetImageInfo(nvjpeg_handle,
x_data,
static_cast<size_t>(x.numel()),
&components,
&subsampling,
widths,
heights);
PADDLE_ENFORCE_EQ(info_status,
NVJPEG_STATUS_SUCCESS,
errors::Fatal("nvjpegGetImageInfo failed: ", info_status));
int width = widths[0];
int height = heights[0];
nvjpegOutputFormat_t output_format;
int output_components;
if (mode == "unchanged") {
if (components == 1) {
output_format = NVJPEG_OUTPUT_Y;
output_components = 1;
} else if (components == 3) {
output_format = NVJPEG_OUTPUT_RGB;
output_components = 3;
} else {
phi::dynload::nvjpegJpegStateDestroy(nvjpeg_state);
PADDLE_THROW(errors::Fatal(
"The provided mode is not supported for JPEG files on GPU"));
}
} else if (mode == "gray") {
output_format = NVJPEG_OUTPUT_Y;
output_components = 1;
} else if (mode == "rgb") {
output_format = NVJPEG_OUTPUT_RGB;
output_components = 3;
} else {
phi::dynload::nvjpegJpegStateDestroy(nvjpeg_state);
PADDLE_THROW(errors::Fatal(
"The provided mode is not supported for JPEG files on GPU"));
}
nvjpegImage_t out_image;
InitNvjpegImage(&out_image);
// create nvjpeg stream
if (nvjpeg_stream == nullptr) {
cudaStreamCreateWithFlags(&nvjpeg_stream, cudaStreamNonBlocking);
}
int sz = widths[0] * heights[0];
std::vector<int64_t> out_shape = {output_components, height, width};
out->Resize(phi::make_ddim(out_shape));
T* data = dev_ctx.template Alloc<T>(out);
for (int c = 0; c < output_components; c++) {
out_image.channel[c] = data + c * sz;
out_image.pitch[c] = width;
}
nvjpegStatus_t decode_status = phi::dynload::nvjpegDecode(nvjpeg_handle,
                                                          nvjpeg_state,
                                                          x_data,
                                                          x.numel(),
                                                          output_format,
                                                          &out_image,
                                                          nvjpeg_stream);
PADDLE_ENFORCE_EQ(decode_status,
                  NVJPEG_STATUS_SUCCESS,
                  errors::Fatal("nvjpegDecode failed: ", decode_status));
}
} // namespace phi
PD_REGISTER_KERNEL(decode_jpeg, // cuda_only
GPU,
ALL_LAYOUT,
phi::DecodeJpegKernel,
uint8_t) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature DecodeJpegOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("decode_jpeg", {"X"}, {"mode"}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(decode_jpeg, phi::DecodeJpegOpArgumentMapping);
@@ -1030,7 +1030,6 @@ def decode_jpeg(x, mode='unchanged', name=None):
print(img.shape)
"""
if _non_static_mode():
return _C_ops.decode_jpeg(x, "mode", mode)