diff --git a/paddle/fluid/operators/decode_jpeg_op.cc b/paddle/fluid/operators/decode_jpeg_op.cc
index 0bf691170740a1327989038ab6c89f9d48ed36cf..789976c887bbeb9d3a374ac2c15a25a53352436f 100644
--- a/paddle/fluid/operators/decode_jpeg_op.cc
+++ b/paddle/fluid/operators/decode_jpeg_op.cc
@@ -17,48 +17,20 @@
 #include <vector>
 
 #include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename T>
-class CPUDecodeJpegKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    // TODO(LieLinJiang): add cpu implement.
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "DecodeJpeg op only supports GPU now."));
-  }
-};
-
 class DecodeJpegOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DecodeJpeg");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "DecodeJpeg");
-
-    auto mode = ctx->Attrs().Get<std::string>("mode");
-    std::vector<int64_t> out_dims;
-
-    if (mode == "unchanged") {
-      out_dims = {-1, -1, -1};
-    } else if (mode == "gray") {
-      out_dims = {1, -1, -1};
-    } else if (mode == "rgb") {
-      out_dims = {3, -1, -1};
-    } else {
-      PADDLE_THROW(platform::errors::Fatal(
-          "The provided mode is not supported for JPEG files on GPU: ", mode));
-    }
-
-    ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -107,12 +79,14 @@ and 255.
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
+DECLARE_INFER_SHAPE_FUNCTOR(decode_jpeg,
+                            DecodeJpegInferShapeFunctor,
+                            PD_INFER_META(phi::DecodeJpegInferMeta));
 REGISTER_OPERATOR(
     decode_jpeg,
     ops::DecodeJpegOp,
     ops::DecodeJpegOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
-
-REGISTER_OP_CPU_KERNEL(decode_jpeg, ops::CPUDecodeJpegKernel<uint8_t>)
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    DecodeJpegInferShapeFunctor)
diff --git a/paddle/fluid/operators/decode_jpeg_op.cu b/paddle/fluid/operators/decode_jpeg_op.cu
deleted file mode 100644
index 589611292c912078b941c3d9120b03b9c1542361..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/decode_jpeg_op.cu
+++ /dev/null
@@ -1,151 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#if !defined(WITH_NV_JETSON) && !defined(PADDLE_WITH_HIP)
-
-#include <vector>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/dynload/nvjpeg.h"
-#include "paddle/fluid/platform/enforce.h"
-
-namespace paddle {
-namespace operators {
-
-static cudaStream_t nvjpeg_stream = nullptr;
-static nvjpegHandle_t nvjpeg_handle = nullptr;
-
-void InitNvjpegImage(nvjpegImage_t* img) {
-  for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) {
-    img->channel[c] = nullptr;
-    img->pitch[c] = 0;
-  }
-}
-
-template <typename T>
-class GPUDecodeJpegKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    // Create nvJPEG handle
-    if (nvjpeg_handle == nullptr) {
-      nvjpegStatus_t create_status =
-          platform::dynload::nvjpegCreateSimple(&nvjpeg_handle);
-
-      PADDLE_ENFORCE_EQ(create_status,
-                        NVJPEG_STATUS_SUCCESS,
-                        platform::errors::Fatal("nvjpegCreateSimple failed: ",
-                                                create_status));
-    }
-
-    nvjpegJpegState_t nvjpeg_state;
-    nvjpegStatus_t state_status =
-        platform::dynload::nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state);
-
-    PADDLE_ENFORCE_EQ(state_status,
-                      NVJPEG_STATUS_SUCCESS,
-                      platform::errors::Fatal("nvjpegJpegStateCreate failed: ",
-                                              state_status));
-
-    int components;
-    nvjpegChromaSubsampling_t subsampling;
-    int widths[NVJPEG_MAX_COMPONENT];
-    int heights[NVJPEG_MAX_COMPONENT];
-
-    auto* x = ctx.Input<framework::LoDTensor>("X");
-    auto* x_data = x->data<T>();
-
-    nvjpegStatus_t info_status =
-        platform::dynload::nvjpegGetImageInfo(nvjpeg_handle,
-                                              x_data,
-                                              static_cast<size_t>(x->numel()),
-                                              &components,
-                                              &subsampling,
-                                              widths,
-                                              heights);
-
-    PADDLE_ENFORCE_EQ(
-        info_status,
-        NVJPEG_STATUS_SUCCESS,
-        platform::errors::Fatal("nvjpegGetImageInfo failed: ", info_status));
-
-    int width = widths[0];
-    int height = heights[0];
-
-    nvjpegOutputFormat_t output_format;
-    int output_components;
-
-    auto mode = ctx.Attr<std::string>("mode");
-    if (mode == "unchanged") {
-      if (components == 1) {
-        output_format = NVJPEG_OUTPUT_Y;
-        output_components = 1;
-      } else if (components == 3) {
-        output_format = NVJPEG_OUTPUT_RGB;
-        output_components = 3;
-      } else {
-        platform::dynload::nvjpegJpegStateDestroy(nvjpeg_state);
-        PADDLE_THROW(platform::errors::Fatal(
-            "The provided mode is not supported for JPEG files on GPU"));
-      }
-    } else if (mode == "gray") {
-      output_format = NVJPEG_OUTPUT_Y;
-      output_components = 1;
-    } else if (mode == "rgb") {
-      output_format = NVJPEG_OUTPUT_RGB;
-      output_components = 3;
-    } else {
-      platform::dynload::nvjpegJpegStateDestroy(nvjpeg_state);
-      PADDLE_THROW(platform::errors::Fatal(
-          "The provided mode is not supported for JPEG files on GPU"));
-    }
-
-    nvjpegImage_t out_image;
-    InitNvjpegImage(&out_image);
-
-    // create nvjpeg stream
-    if (nvjpeg_stream == nullptr) {
-      cudaStreamCreateWithFlags(&nvjpeg_stream, cudaStreamNonBlocking);
-    }
-
-    int sz = widths[0] * heights[0];
-
-    auto* out = ctx.Output<framework::LoDTensor>("Out");
-    std::vector<int64_t> out_shape = {output_components, height, width};
-    out->Resize(phi::make_ddim(out_shape));
-
-    T* data = out->mutable_data<T>(ctx.GetPlace());
-
-    for (int c = 0; c < output_components; c++) {
-      out_image.channel[c] = data + c * sz;
-      out_image.pitch[c] = width;
-    }
-
-    nvjpegStatus_t decode_status =
-        platform::dynload::nvjpegDecode(nvjpeg_handle,
-                                        nvjpeg_state,
-                                        x_data,
-                                        x->numel(),
-                                        output_format,
-                                        &out_image,
-                                        nvjpeg_stream);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(decode_jpeg, ops::GPUDecodeJpegKernel<uint8_t>)
-
-#endif
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index fdd86857ed7ab1658e24517c1f8d69af3a291818..be8602b74ac8581fe5cee28f74a6cda9fb59ef16 100755
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -585,6 +585,15 @@
     func : cumsum
   backward : cumsum_grad
 
+# decode_jpeg
+- api : decode_jpeg
+  args : (Tensor x, str mode)
+  output : Tensor(out)
+  infer_meta :
+    func : DecodeJpegInferMeta
+  kernel :
+    func : decode_jpeg
+
 - api : deformable_conv
   args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step)
   output : Tensor(out)
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 8389615f38684e5e825af223a541a869ed6966f0..a659909df91a7e39d41c45a746fd72fde3f68d60 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -341,6 +341,27 @@ void CropTensorInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
 
+void DecodeJpegInferMeta(const MetaTensor& x,
+                         const std::string& mode,
+                         MetaTensor* out) {
+  std::vector<int64_t> out_dims;
+
+  if (mode == "unchanged") {
+    out_dims = {-1, -1, -1};
+  } else if (mode == "gray") {
+    out_dims = {1, -1, -1};
+  } else if (mode == "rgb") {
+    out_dims = {3, -1, -1};
+  } else {
+    PADDLE_THROW(errors::Fatal(
+        "The provided mode is not supported for JPEG files on GPU: ", mode));
+  }
+  if (out != nullptr) {
+    out->set_dims(phi::make_ddim(out_dims));
+    out->set_dtype(x.dtype());
+  }
+}
+
 void DiagEmbedInferMeta(
     const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out) {
   auto x_dims = x.dims();
@@ -1546,11 +1567,10 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
 
   auto x_dims = x.dims();
 
-  PADDLE_ENFORCE(
-      x_dims.size() == 4 || x_dims.size() == 5,
-      errors::InvalidArgument(
-          "Pooling intput should be 4-D or 5-D tensor but received %dD-Tensor",
-          x_dims.size()));
+  PADDLE_ENFORCE(x_dims.size() == 4 || x_dims.size() == 5,
+                 errors::InvalidArgument("Pooling input should be 4-D or "
+                                         "5-D tensor but received %dD-Tensor",
+                                         x_dims.size()));
 
   if (global_pooling) {
     kernel_size_.resize(static_cast<size_t>(x_dims.size()) - 2);
@@ -3032,7 +3052,8 @@ void StridedSliceInferMeta(const MetaTensor& x,
 }
 
 /* Why not use SumRawInferMeta directly?
-   Because we need make InferMetaFunction's args follow the design of api.yaml
+   Because we need to make InferMetaFunction's args follow the design of
+   api.yaml
 */
 void SumInferMeta(const MetaTensor& x,
                   const std::vector<int64_t>& axis,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 72e6be818a22762c5c2592877676dc01297ab6cc..b8fe4a22052bcab7e162ad8de47d1789fd68604e 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -79,6 +79,10 @@ void CumInferMeta(const MetaTensor& x,
                   bool reverse,
                   MetaTensor* out);
 
+void DecodeJpegInferMeta(const MetaTensor& x,
+                         const std::string& mode,
+                         MetaTensor* out);
+
 void DiagEmbedInferMeta(
     const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out);
 
diff --git a/paddle/phi/kernels/cpu/decode_jpeg_kernel.cc b/paddle/phi/kernels/cpu/decode_jpeg_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..aceced1ce853134b0027934ec28a1bdd6881d92d
--- /dev/null
+++ b/paddle/phi/kernels/cpu/decode_jpeg_kernel.cc
@@ -0,0 +1,32 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/decode_jpeg_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void DecodeJpegKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const std::string& mode,
+                      DenseTensor* out) {
+  PADDLE_THROW(errors::Unimplemented("DecodeJpeg op only supports GPU now."));
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    decode_jpeg, CPU, ALL_LAYOUT, phi::DecodeJpegKernel, uint8_t) {}
diff --git a/paddle/phi/kernels/decode_jpeg_kernel.h b/paddle/phi/kernels/decode_jpeg_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..55227516eb644227c82d7aa7dfb239f222976611
--- /dev/null
+++ b/paddle/phi/kernels/decode_jpeg_kernel.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void DecodeJpegKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const std::string& mode,
+                      DenseTensor* out);
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/decode_jpeg_kernel.cu b/paddle/phi/kernels/gpu/decode_jpeg_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0b5a10b93d85a156c12b8b522ac52da256d9c656
--- /dev/null
+++ b/paddle/phi/kernels/gpu/decode_jpeg_kernel.cu
@@ -0,0 +1,148 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#if !defined(WITH_NV_JETSON) && !defined(PADDLE_WITH_HIP)
+
+#include "paddle/phi/kernels/decode_jpeg_kernel.h"
+
+#include "paddle/phi/backends/dynload/nvjpeg.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/stream.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+static cudaStream_t nvjpeg_stream = nullptr;
+static nvjpegHandle_t nvjpeg_handle = nullptr;
+
+void InitNvjpegImage(nvjpegImage_t* img) {
+  for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) {
+    img->channel[c] = nullptr;
+    img->pitch[c] = 0;
+  }
+}
+
+template <typename T, typename Context>
+void DecodeJpegKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const std::string& mode,
+                      DenseTensor* out) {
+  // Create nvJPEG handle
+  if (nvjpeg_handle == nullptr) {
+    nvjpegStatus_t create_status =
+        phi::dynload::nvjpegCreateSimple(&nvjpeg_handle);
+
+    PADDLE_ENFORCE_EQ(
+        create_status,
+        NVJPEG_STATUS_SUCCESS,
+        errors::Fatal("nvjpegCreateSimple failed: ", create_status));
+  }
+
+  nvjpegJpegState_t nvjpeg_state;
+  nvjpegStatus_t state_status =
+      phi::dynload::nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state);
+
+  PADDLE_ENFORCE_EQ(
+      state_status,
+      NVJPEG_STATUS_SUCCESS,
+      errors::Fatal("nvjpegJpegStateCreate failed: ", state_status));
+
+  int components;
+  nvjpegChromaSubsampling_t subsampling;
+  int widths[NVJPEG_MAX_COMPONENT];
+  int heights[NVJPEG_MAX_COMPONENT];
+
+  auto* x_data = x.data<T>();
+
+  nvjpegStatus_t info_status =
+      phi::dynload::nvjpegGetImageInfo(nvjpeg_handle,
+                                       x_data,
+                                       (std::size_t)x.numel(),
+                                       &components,
+                                       &subsampling,
+                                       widths,
+                                       heights);
+  PADDLE_ENFORCE_EQ(info_status,
+                    NVJPEG_STATUS_SUCCESS,
+                    errors::Fatal("nvjpegGetImageInfo failed: ", info_status));
+
+  int width = widths[0];
+  int height = heights[0];
+
+  nvjpegOutputFormat_t output_format;
+  int output_components;
+
+  if (mode == "unchanged") {
+    if (components == 1) {
+      output_format = NVJPEG_OUTPUT_Y;
+      output_components = 1;
+    } else if (components == 3) {
+      output_format = NVJPEG_OUTPUT_RGB;
+      output_components = 3;
+    } else {
+      phi::dynload::nvjpegJpegStateDestroy(nvjpeg_state);
+      PADDLE_THROW(errors::Fatal(
+          "The provided mode is not supported for JPEG files on GPU"));
+    }
+  } else if (mode == "gray") {
+    output_format = NVJPEG_OUTPUT_Y;
+    output_components = 1;
+  } else if (mode == "rgb") {
+    output_format = NVJPEG_OUTPUT_RGB;
+    output_components = 3;
+  } else {
+    phi::dynload::nvjpegJpegStateDestroy(nvjpeg_state);
+    PADDLE_THROW(errors::Fatal(
+        "The provided mode is not supported for JPEG files on GPU"));
+  }
+
+  nvjpegImage_t out_image;
+  InitNvjpegImage(&out_image);
+
+  // create nvjpeg stream
+  if (nvjpeg_stream == nullptr) {
+    cudaStreamCreateWithFlags(&nvjpeg_stream, cudaStreamNonBlocking);
+  }
+
+  int sz = widths[0] * heights[0];
+
+  std::vector<int64_t> out_shape = {output_components, height, width};
+  out->Resize(phi::make_ddim(out_shape));
+
+  T* data = dev_ctx.template Alloc<T>(out);
+
+  for (int c = 0; c < output_components; c++) {
+    out_image.channel[c] = data + c * sz;
+    out_image.pitch[c] = width;
+  }
+
+  nvjpegStatus_t decode_status = phi::dynload::nvjpegDecode(nvjpeg_handle,
+                                                            nvjpeg_state,
+                                                            x_data,
+                                                            x.numel(),
+                                                            output_format,
+                                                            &out_image,
+                                                            nvjpeg_stream);
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(decode_jpeg,  // cuda_only
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::DecodeJpegKernel,
+                   uint8_t) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+}
+
+#endif
diff --git a/paddle/phi/ops/compat/decode_jpeg_sig.cc b/paddle/phi/ops/compat/decode_jpeg_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8c3d2399caeccdb2e6366782b393c6b0d4cc4bdc
--- /dev/null
+++ b/paddle/phi/ops/compat/decode_jpeg_sig.cc
@@ -0,0 +1,25 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature DecodeJpegOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("decode_jpeg", {"X"}, {"mode"}, {"Out"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(decode_jpeg, phi::DecodeJpegOpArgumentMapping);
diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py
index 20ea1417857163e1d1790906a65611f597c124ec..484fcf95cb269b403a0a769308e0b57118cdebb4 100755
--- a/python/paddle/vision/ops.py
+++ b/python/paddle/vision/ops.py
@@ -1030,7 +1030,6 @@ def decode_jpeg(x, mode='unchanged', name=None):
             print(img.shape)
     """
-
     if _non_static_mode():
         return _C_ops.decode_jpeg(x, "mode", mode)
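
Usage sketch (not part of the patch): the public Python entry point is untouched, so decode_jpeg is still called through paddle.vision.ops and, in dygraph mode, now dispatches to the phi kernel registered above. A minimal example, assuming a GPU build with nvJPEG available (the CPU kernel added here still raises Unimplemented); OpenCV and the file name 'fake.jpg' are used only to synthesize test input and are illustrative, not required by the patch:

    import cv2
    import numpy as np
    import paddle

    # Write a random 400x300 RGB image to disk as a JPEG.
    fake_img = (np.random.random((400, 300, 3)) * 255).astype('uint8')
    cv2.imwrite('fake.jpg', fake_img)

    # Read the raw file bytes back as a 1-D uint8 tensor, then decode.
    img_bytes = paddle.vision.ops.read_file('fake.jpg')
    img = paddle.vision.ops.decode_jpeg(img_bytes, mode='unchanged')

    print(img.shape)  # CHW layout, e.g. [3, 400, 300]

With mode='gray' the DecodeJpegInferMeta above pins the channel dimension to 1 and with 'rgb' to 3, while 'unchanged' leaves all three dimensions dynamic until nvJPEG parses the image header at run time.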