From 9968c56321a74c51fb762cb583f80bac6de90e6f Mon Sep 17 00:00:00 2001
From: chenenquan
Date: Wed, 9 Mar 2022 11:53:36 +0800
Subject: [PATCH] [Phi] Migrate linspace op to phi (#40124)

* [Phi] Migrate linspace op

* [Phi] Migrate linspace op

* [Phi] Fix linspace op

* [PHI] rename data_tranform to data_type_transform

* [PHI] Fix DECLARE and PD
---
 paddle/fluid/operators/linspace_op.cc       |  45 ++------
 paddle/fluid/operators/linspace_op.cu       | 104 ------------------
 paddle/fluid/operators/linspace_op.h        |  76 -------------
 paddle/phi/infermeta/ternary.cc             |  29 +++++
 paddle/phi/infermeta/ternary.h              |   5 +
 paddle/phi/kernels/cpu/linspace_kernel.cc   |  71 ++++++++++++
 .../phi/kernels/funcs/data_type_transform.h |  58 ++++++++++
 paddle/phi/kernels/gpu/linspace_kernel.cu   |  97 ++++++++++++++++
 paddle/phi/kernels/linspace_kernel.h        |  26 +++++
 9 files changed, 298 insertions(+), 213 deletions(-)
 delete mode 100644 paddle/fluid/operators/linspace_op.cu
 delete mode 100644 paddle/fluid/operators/linspace_op.h
 create mode 100644 paddle/phi/kernels/cpu/linspace_kernel.cc
 create mode 100644 paddle/phi/kernels/funcs/data_type_transform.h
 create mode 100644 paddle/phi/kernels/gpu/linspace_kernel.cu
 create mode 100644 paddle/phi/kernels/linspace_kernel.h

diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc
index fe271fa5e8..378c7573d6 100644
--- a/paddle/fluid/operators/linspace_op.cc
+++ b/paddle/fluid/operators/linspace_op.cc
@@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/linspace_op.h"
 #include <string>
+
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/ternary.h"
 
 namespace paddle {
 namespace operators {
@@ -23,33 +27,6 @@ class LinspaceOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Start"), "Input", "Start", "linspace");
-    OP_INOUT_CHECK(ctx->HasInput("Stop"), "Input", "Stop", "linspace");
-    OP_INOUT_CHECK(ctx->HasInput("Num"), "Input", "Num", "linspace");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "linspace");
-
-    auto s_dims = ctx->GetInputDim("Start");
-    PADDLE_ENFORCE_EQ((s_dims.size() == 1) && (s_dims[0] == 1), true,
-                      platform::errors::InvalidArgument(
-                          "The shape of Input(Start) must be [1],"
-                          "but received input shape is [%s].",
-                          s_dims));
-    auto e_dims = ctx->GetInputDim("Stop");
-    PADDLE_ENFORCE_EQ((e_dims.size() == 1) && (e_dims[0] == 1), true,
-                      platform::errors::InvalidArgument(
-                          "The shape of Input(Stop) must be [1],"
-                          "but received input shape is [%s].",
-                          e_dims));
-    auto step_dims = ctx->GetInputDim("Num");
-    PADDLE_ENFORCE_EQ(
-        (step_dims.size() == 1) && (step_dims[0] == 1), true,
-        platform::errors::InvalidArgument("The shape of Input(Num) must be [1],"
-                                          "but received input shape is [%s].",
-                                          step_dims));
-    ctx->SetOutputDim("Out", {-1});
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -88,11 +65,13 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(linspace, ops::LinspaceOp, ops::LinspaceOpMaker);
-REGISTER_OP_CPU_KERNEL(linspace, ops::CPULinspaceKernel<float>,
-                       ops::CPULinspaceKernel<int32_t>,
-                       ops::CPULinspaceKernel<int64_t>,
-                       ops::CPULinspaceKernel<double>);
+DECLARE_INFER_SHAPE_FUNCTOR(linspace, LinspaceInferShapeFunctor,
+                            PD_INFER_META(phi::LinspaceInferMeta));
+REGISTER_OPERATOR(
+    linspace, ops::LinspaceOp, ops::LinspaceOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    LinspaceInferShapeFunctor);
 
 REGISTER_OP_VERSION(linspace)
     .AddCheckpoint(
diff --git a/paddle/fluid/operators/linspace_op.cu b/paddle/fluid/operators/linspace_op.cu
deleted file mode 100644
index aa625a7f5b..0000000000
--- a/paddle/fluid/operators/linspace_op.cu
+++ /dev/null
@@ -1,104 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/data_type_transform.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/linspace_op.h"
-#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-__global__ void LinspaceKernel(T start, T stop, double step, int64_t size,
-                               T* out) {
-  int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
-
-  for (; index < size; index += blockDim.x * gridDim.x) {
-    if (index < size / 2) {
-      out[index] = static_cast<T>(start + step * index);
-    } else {
-      out[index] = static_cast<T>(stop - step * (size - index - 1));
-    }
-  }
-}
-
-template <typename T>
-__global__ void LinspaceSpecialKernel(T start, T* out) {
-  out[0] = static_cast<T>(start);
-}
-
-template <typename T>
-class CUDALinspaceKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* pre_start = context.Input<framework::Tensor>("Start");
-    auto* pre_stop = context.Input<framework::Tensor>("Stop");
-    auto* num_t = context.Input<framework::Tensor>("Num");
-    auto* out = context.Output<framework::Tensor>("Out");
-    auto dtype = static_cast<framework::proto::VarType::Type>(
-        context.Attr<int>("dtype"));
-
-    Tensor start_t;
-    Tensor stop_t;
-    auto start_dtype = framework::OpKernelType(
-        framework::TransToProtoVarType(pre_start->dtype()), context.GetPlace());
-    auto stop_dtype = framework::OpKernelType(
-        framework::TransToProtoVarType(pre_stop->dtype()), context.GetPlace());
-    auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
-    framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
-    framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
-
-    framework::Tensor n_start;
-    framework::Tensor n_stop;
-    framework::Tensor n_num;
-    framework::TensorCopy(start_t, platform::CPUPlace(), &n_start);
-    T start = n_start.data<T>()[0];
-    framework::TensorCopy(stop_t, platform::CPUPlace(), &n_stop);
-    T stop = n_stop.data<T>()[0];
-    framework::TensorCopy(*num_t, platform::CPUPlace(), &n_num);
-    int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]);
-
-    PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
-                                  "The num of linspace op should be larger "
-                                  "than 0, but received num is %d",
-                                  num));
-
-    out->Resize(phi::make_ddim({num}));
-    T* out_data = out->mutable_data<T>(context.GetPlace());
-
-    double step = 0;
-    auto stream = context.cuda_device_context().stream();
-    int block = 512;
-    int grid = (num + block - 1) / block;
-    if (num != 1) {
-      step = (static_cast<double>(stop - start)) / (num - 1);
-      LinspaceKernel<T><<<grid, block, 0, stream>>>(start, stop, step, num,
-                                                    out_data);
-    } else {
-      LinspaceSpecialKernel<T><<<1, 1, 0, stream>>>(start, out_data);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(linspace, ops::CUDALinspaceKernel<float>,
-                        ops::CUDALinspaceKernel<int32_t>,
-                        ops::CUDALinspaceKernel<int64_t>,
-                        ops::CUDALinspaceKernel<double>);
diff --git a/paddle/fluid/operators/linspace_op.h b/paddle/fluid/operators/linspace_op.h
deleted file mode 100644
index ae51f1221c..0000000000
--- a/paddle/fluid/operators/linspace_op.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <functional>
-#include "paddle/fluid/framework/data_type_transform.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-class CPULinspaceKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* pre_start = context.Input<framework::Tensor>("Start");
-    auto* pre_stop = context.Input<framework::Tensor>("Stop");
-    int32_t num = context.Input<framework::Tensor>("Num")->data<int32_t>()[0];
-    auto* out = context.Output<framework::Tensor>("Out");
-    auto dtype = static_cast<framework::proto::VarType::Type>(
-        context.Attr<int>("dtype"));
-
-    Tensor start_t;
-    Tensor stop_t;
-    auto start_dtype = framework::OpKernelType(
-        framework::TransToProtoVarType(pre_start->dtype()), context.GetPlace());
-    auto stop_dtype = framework::OpKernelType(
-        framework::TransToProtoVarType(pre_stop->dtype()), context.GetPlace());
-    auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
-    framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
-    framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);
-
-    T start = start_t.data<T>()[0];
-    T stop = stop_t.data<T>()[0];
-    PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
-                                  "The num of linspace op should be larger "
-                                  "than 0, but received num is %d",
-                                  num));
-
-    out->Resize(phi::make_ddim({num}));
-
-    T* out_data = out->mutable_data<T>(context.GetPlace());
-
-    if (num > 1) {
-      // step should be of double type for all types
-      double step = (static_cast<double>(stop - start)) / (num - 1);
-      int half_num = num / 2;
-      for (int i = 0; i < num; ++i) {
-        if (i < half_num) {
-          out_data[i] = static_cast<T>(start + step * i);
-        } else {
-          out_data[i] = static_cast<T>(stop - step * (num - i - 1));
-        }
-      }
-    } else {
-      out_data[0] = static_cast<T>(start);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index c3472a2480..eb807ad461 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -209,4 +209,33 @@ void LerpInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void LinspaceInferMeta(const MetaTensor& start,
+                       const MetaTensor& stop,
+                       const MetaTensor& number,
+                       MetaTensor* out) {
+  auto s_dims = start.dims();
+  PADDLE_ENFORCE_EQ(
+      (s_dims.size() == 1) && (s_dims[0] == 1),
+      true,
+      phi::errors::InvalidArgument("The shape of Input(Start) must be [1],"
+                                   "but received input shape is [%s].",
+                                   s_dims));
+  auto e_dims = stop.dims();
+  PADDLE_ENFORCE_EQ(
+      (e_dims.size() == 1) && (e_dims[0] == 1),
+      true,
+      phi::errors::InvalidArgument("The shape of Input(Stop) must be [1],"
+                                   "but received input shape is [%s].",
+                                   e_dims));
+  auto step_dims = number.dims();
+  PADDLE_ENFORCE_EQ(
+      (step_dims.size() == 1) && (step_dims[0] == 1),
+      true,
+      phi::errors::InvalidArgument("The shape of Input(Num) must be [1],"
+                                   "but received input shape is [%s].",
+                                   step_dims));
+  out->set_dims(phi::make_ddim({-1}));
+  out->set_dtype(start.dtype());
+}
+
 }  // namespace phi
diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h
index cff57e1ba7..4dec144251 100644
--- a/paddle/phi/infermeta/ternary.h
+++ b/paddle/phi/infermeta/ternary.h
@@ -58,4 +58,9 @@ void LerpInferMeta(const MetaTensor& x,
                    const MetaTensor& weight,
                    MetaTensor* out);
 
+void LinspaceInferMeta(const MetaTensor& start,
+                       const MetaTensor& stop,
+                       const MetaTensor& number,
+                       MetaTensor* out);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/linspace_kernel.cc b/paddle/phi/kernels/cpu/linspace_kernel.cc
new file mode 100644
index 0000000000..4b8b7f7a2e
--- /dev/null
+++ b/paddle/phi/kernels/cpu/linspace_kernel.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/linspace_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/data_type_transform.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LinspaceKernel(const Context& ctx,
+                    const DenseTensor& start,
+                    const DenseTensor& stop,
+                    const DenseTensor& number,
+                    DataType dtype,
+                    DenseTensor* out) {
+  int32_t num = number.data<int32_t>()[0];
+  auto start_t = phi::funcs::TransDataType(ctx, start, dtype);
+  auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype);
+
+  T start_data = start_t.template data<T>()[0];
+  T stop_data = stop_t.template data<T>()[0];
+  PADDLE_ENFORCE_GT(
+      num,
+      0,
+      phi::errors::InvalidArgument("The num of linspace op should be larger "
+                                   "than 0, but received num is %d",
+                                   num));
+
+  out->Resize(phi::make_ddim({num}));
+  T* out_data = ctx.template Alloc<T>(out);
+
+  if (num > 1) {
+    // step should be of double type for all types
+    double step = (static_cast<double>(stop_data - start_data)) / (num - 1);
+    int half_num = num / 2;
+    for (int i = 0; i < num; ++i) {
+      if (i < half_num) {
+        out_data[i] = static_cast<T>(start_data + step * i);
+      } else {
+        out_data[i] = static_cast<T>(stop_data - step * (num - i - 1));
+      }
+    }
+  } else {
+    out_data[0] = static_cast<T>(start_data);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(linspace,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::LinspaceKernel,
+                   float,
+                   int32_t,
+                   int64_t,
+                   double) {}
diff --git a/paddle/phi/kernels/funcs/data_type_transform.h b/paddle/phi/kernels/funcs/data_type_transform.h
new file mode 100644
index 0000000000..ad7f2aa192
--- /dev/null
+++ b/paddle/phi/kernels/funcs/data_type_transform.h
@@ -0,0 +1,58 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/cast_kernel.h"
+
+namespace phi {
+namespace funcs {
+
+template <typename Context>
+phi::DenseTensor TransDataType(const Context& dev_ctx,
+                               const phi::DenseTensor& x,
+                               DataType dtype) {
+  VLOG(3) << "TransDataType "
+          << "src type:" << x.dtype() << "; dst type: " << dtype;
+
+  switch (x.dtype()) {
+    case DataType::FLOAT32:
+      return phi::Cast<float>(dev_ctx, x, dtype);
+    case DataType::FLOAT64:
+      return phi::Cast<double>(dev_ctx, x, dtype);
+    case DataType::INT32:
+      return phi::Cast<int32_t>(dev_ctx, x, dtype);
+    case DataType::INT64:
+      return phi::Cast<int64_t>(dev_ctx, x, dtype);
+    case DataType::FLOAT16:
+      return phi::Cast<phi::dtype::float16>(dev_ctx, x, dtype);
+    case DataType::BFLOAT16:
+      return phi::Cast<phi::dtype::bfloat16>(dev_ctx, x, dtype);
+    case DataType::BOOL:
+      return phi::Cast<bool>(dev_ctx, x, dtype);
+    case DataType::INT16:
+      return phi::Cast<int16_t>(dev_ctx, x, dtype);
+    case DataType::UINT8:
+      return phi::Cast<uint8_t>(dev_ctx, x, dtype);
+    default:
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "Data type (%s) is not supported when casting data type.",
+          x.dtype()));
+  }
+}
+
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/linspace_kernel.cu b/paddle/phi/kernels/gpu/linspace_kernel.cu
new file mode 100644
index 0000000000..3a6ff365c1
--- /dev/null
+++ b/paddle/phi/kernels/gpu/linspace_kernel.cu
@@ -0,0 +1,97 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/linspace_kernel.h"
+
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/copy_kernel.h"
+#include "paddle/phi/kernels/funcs/data_type_transform.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T>
+__global__ void LinspaceKernelInner(
+    T start, T stop, double step, int64_t size, T* out) {
+  int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
+
+  for (; index < size; index += blockDim.x * gridDim.x) {
+    if (index < size / 2) {
+      out[index] = static_cast<T>(start + step * index);
+    } else {
+      out[index] = static_cast<T>(stop - step * (size - index - 1));
+    }
+  }
+}
+
+template <typename T>
+__global__ void LinspaceSpecialKernel(T start, T* out) {
+  out[0] = static_cast<T>(start);
+}
+
+template <typename T, typename Context>
+void LinspaceKernel(const Context& ctx,
+                    const DenseTensor& start,
+                    const DenseTensor& stop,
+                    const DenseTensor& number,
+                    DataType dtype,
+                    DenseTensor* out) {
+  auto start_t = phi::funcs::TransDataType(ctx, start, dtype);
+  auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype);
+
+  DenseTensor n_start;
+  DenseTensor n_stop;
+  DenseTensor n_num;
+  phi::Copy(ctx, start_t, phi::CPUPlace(), false, &n_start);
+  T start_data = n_start.data<T>()[0];
+  phi::Copy(ctx, stop_t, phi::CPUPlace(), false, &n_stop);
+  T stop_data = n_stop.data<T>()[0];
+  phi::Copy(ctx, number, phi::CPUPlace(), false, &n_num);
+  int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]);
+
+  PADDLE_ENFORCE_GT(
+      num,
+      0,
+      phi::errors::InvalidArgument("The num of linspace op should be larger "
+                                   "than 0, but received num is %d",
+                                   num));
+
+  out->Resize(phi::make_ddim({num}));
+  T* out_data = ctx.template Alloc<T>(out);
+
+  double step = 0;
+  auto stream = ctx.stream();
+  int block = 512;
+  int grid = (num + block - 1) / block;
+  if (num != 1) {
+    step = (static_cast<double>(stop_data - start_data)) / (num - 1);
+    LinspaceKernelInner<T><<<grid, block, 0, stream>>>(
+        start_data, stop_data, step, num, out_data);
+  } else {
+    LinspaceSpecialKernel<T><<<1, 1, 0, stream>>>(start_data, out_data);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(linspace,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LinspaceKernel,
+                   float,
+                   int32_t,
+                   int64_t,
+                   double) {}
diff --git a/paddle/phi/kernels/linspace_kernel.h b/paddle/phi/kernels/linspace_kernel.h
new file mode 100644
index 0000000000..ca2b940aef
--- /dev/null
+++ b/paddle/phi/kernels/linspace_kernel.h
@@ -0,0 +1,26 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LinspaceKernel(const Context& ctx,
+                    const DenseTensor& start,
+                    const DenseTensor& stop,
+                    const DenseTensor& number,
+                    DataType dtype,
+                    DenseTensor* out);
+
+}  // namespace phi
-- 
GitLab