未验证 提交 9968c563 编写于 作者: C chenenquan 提交者: GitHub

[Phi] Migrate linspace op to phi (#40124)

* [Phi] Migrate linspace op

* [Phi] Migrate linspace op

* [Phi] Fix linspace op

* [PHI] rename data_tranform to data_type_transform

* [PHI] Fix DECLARE and PD
上级 2037fa68
...@@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/linspace_op.h"
#include <string> #include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/ternary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -23,33 +27,6 @@ class LinspaceOp : public framework::OperatorWithKernel { ...@@ -23,33 +27,6 @@ class LinspaceOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Start"), "Input", "Start", "linspace");
OP_INOUT_CHECK(ctx->HasInput("Stop"), "Input", "Stop", "linspace");
OP_INOUT_CHECK(ctx->HasInput("Num"), "Input", "Num", "linspace");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "linspace");
auto s_dims = ctx->GetInputDim("Start");
PADDLE_ENFORCE_EQ((s_dims.size() == 1) && (s_dims[0] == 1), true,
platform::errors::InvalidArgument(
"The shape of Input(Start) must be [1],"
"but received input shape is [%s].",
s_dims));
auto e_dims = ctx->GetInputDim("Stop");
PADDLE_ENFORCE_EQ((e_dims.size() == 1) && (e_dims[0] == 1), true,
platform::errors::InvalidArgument(
"The shape of Input(Stop) must be [1],"
"but received input shape is [%s].",
e_dims));
auto step_dims = ctx->GetInputDim("Num");
PADDLE_ENFORCE_EQ(
(step_dims.size() == 1) && (step_dims[0] == 1), true,
platform::errors::InvalidArgument("The shape of Input(Num) must be [1],"
"but received input shape is [%s].",
step_dims));
ctx->SetOutputDim("Out", {-1});
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
...@@ -88,11 +65,13 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -88,11 +65,13 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(linspace, ops::LinspaceOp, ops::LinspaceOpMaker); DECLARE_INFER_SHAPE_FUNCTOR(linspace, LinspaceInferShapeFunctor,
REGISTER_OP_CPU_KERNEL(linspace, ops::CPULinspaceKernel<float>, PD_INFER_META(phi::LinspaceInferMeta));
ops::CPULinspaceKernel<int32_t>, REGISTER_OPERATOR(
ops::CPULinspaceKernel<int64_t>, linspace, ops::LinspaceOp, ops::LinspaceOpMaker,
ops::CPULinspaceKernel<double>); paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
LinspaceInferShapeFunctor);
REGISTER_OP_VERSION(linspace) REGISTER_OP_VERSION(linspace)
.AddCheckpoint( .AddCheckpoint(
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/linspace_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Grid-stride kernel that fills `out` with `size` evenly spaced values.
// The front half is generated forward from `start` and the back half
// backward from `stop`, which keeps both endpoints exact for float types.
template <typename T>
__global__ void LinspaceKernel(T start, T stop, double step, int64_t size,
                               T* out) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t half = size / 2;
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += stride) {
    out[i] = (i < half) ? static_cast<T>(start + step * i)
                        : static_cast<T>(stop - step * (size - i - 1));
  }
}
// Degenerate case (a single output element): store the one value.
// If launched with more than one thread, all threads write identical
// bytes to out[0], so the duplicate stores are benign.
template <typename T>
__global__ void LinspaceSpecialKernel(T val, T* out) {
  *out = static_cast<T>(val);
}
// GPU kernel for the linspace op: writes `Num` evenly spaced values covering
// [Start, Stop] into `Out`. Start/Stop are 1-element tensors of arbitrary
// numeric dtype and are cast to the output dtype first; Num is a 1-element
// int32 tensor.
template <typename T>
class CUDALinspaceKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* pre_start = context.Input<framework::Tensor>("Start");
    auto* pre_stop = context.Input<framework::Tensor>("Stop");
    auto* num_t = context.Input<framework::Tensor>("Num");
    auto* out = context.Output<framework::Tensor>("Out");
    auto dtype = static_cast<framework::proto::VarType::Type>(
        context.Attr<int>("dtype"));

    // Cast the scalar endpoints to the requested output dtype on the device.
    Tensor start_t;
    Tensor stop_t;
    auto start_dtype = framework::OpKernelType(
        framework::TransToProtoVarType(pre_start->dtype()), context.GetPlace());
    auto stop_dtype = framework::OpKernelType(
        framework::TransToProtoVarType(pre_stop->dtype()), context.GetPlace());
    auto out_dtype = framework::OpKernelType(dtype, context.GetPlace());
    framework::TransDataType(start_dtype, out_dtype, *pre_start, &start_t);
    framework::TransDataType(stop_dtype, out_dtype, *pre_stop, &stop_t);

    // Bring the scalar parameters to the host. These copies must be
    // synchronous: the asynchronous framework::TensorCopy would let the host
    // read n_*.data() below before the device-to-host transfer has finished.
    framework::Tensor n_start;
    framework::Tensor n_stop;
    framework::Tensor n_num;
    framework::TensorCopySync(start_t, platform::CPUPlace(), &n_start);
    T start = n_start.data<T>()[0];
    framework::TensorCopySync(stop_t, platform::CPUPlace(), &n_stop);
    T stop = n_stop.data<T>()[0];
    framework::TensorCopySync(*num_t, platform::CPUPlace(), &n_num);
    int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]);

    PADDLE_ENFORCE_GT(num, 0, platform::errors::InvalidArgument(
                                  "The num of linspace op should be larger "
                                  "than 0, but received num is %d",
                                  num));

    out->Resize(phi::make_ddim({num}));
    T* out_data = out->mutable_data<T>(context.GetPlace());

    double step = 0;
    auto stream = context.cuda_device_context().stream();
    int block = 512;
    int grid = (num + block - 1) / block;
    if (num != 1) {
      // step is computed in double so integral dtypes keep full precision.
      step = (static_cast<double>(stop - start)) / (num - 1);
      LinspaceKernel<T><<<grid, block, 0, stream>>>(start, stop, step, num,
                                                    out_data);
    } else {
      LinspaceSpecialKernel<T><<<grid, block, 0, stream>>>(start, out_data);
    }
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Register the CUDA linspace kernel for every dtype the op supports.
REGISTER_OP_CUDA_KERNEL(linspace, ops::CUDALinspaceKernel<float>,
                        ops::CUDALinspaceKernel<int32_t>,
                        ops::CUDALinspaceKernel<int64_t>,
                        ops::CUDALinspaceKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// CPU kernel for the linspace op: fills `Out` with `Num` evenly spaced
// values covering [Start, Stop]. Start/Stop are 1-element tensors that are
// cast to the output dtype before use; Num is a 1-element int32 tensor.
template <typename T>
class CPULinspaceKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* raw_start = context.Input<framework::Tensor>("Start");
    auto* raw_stop = context.Input<framework::Tensor>("Stop");
    int32_t count = context.Input<framework::Tensor>("Num")->data<int32_t>()[0];
    auto* out = context.Output<framework::Tensor>("Out");
    auto dst_type = static_cast<framework::proto::VarType::Type>(
        context.Attr<int>("dtype"));

    // Cast the scalar endpoints to the requested output dtype.
    Tensor casted_start;
    Tensor casted_stop;
    auto start_kernel_type = framework::OpKernelType(
        framework::TransToProtoVarType(raw_start->dtype()), context.GetPlace());
    auto stop_kernel_type = framework::OpKernelType(
        framework::TransToProtoVarType(raw_stop->dtype()), context.GetPlace());
    auto dst_kernel_type = framework::OpKernelType(dst_type, context.GetPlace());
    framework::TransDataType(start_kernel_type, dst_kernel_type, *raw_start,
                             &casted_start);
    framework::TransDataType(stop_kernel_type, dst_kernel_type, *raw_stop,
                             &casted_stop);
    T lo = casted_start.data<T>()[0];
    T hi = casted_stop.data<T>()[0];

    PADDLE_ENFORCE_GT(count, 0, platform::errors::InvalidArgument(
                                    "The num of linspace op should be larger "
                                    "than 0, but received num is %d",
                                    count));

    out->Resize(phi::make_ddim({count}));
    T* dst = out->mutable_data<T>(context.GetPlace());

    if (count == 1) {
      dst[0] = static_cast<T>(lo);
    } else {
      // step is kept in double so integral dtypes retain full precision.
      double step = (static_cast<double>(hi - lo)) / (count - 1);
      // Generate the front half forward from `lo` and the back half backward
      // from `hi` so both endpoints stay exact.
      int mid = count / 2;
      for (int i = 0; i < count; ++i) {
        dst[i] = i < mid ? static_cast<T>(lo + step * i)
                         : static_cast<T>(hi - step * (count - i - 1));
      }
    }
  }
};
} // namespace operators
} // namespace paddle
...@@ -209,4 +209,33 @@ void LerpInferMeta(const MetaTensor& x, ...@@ -209,4 +209,33 @@ void LerpInferMeta(const MetaTensor& x,
out->share_lod(x); out->share_lod(x);
} }
void LinspaceInferMeta(const MetaTensor& start,
                       const MetaTensor& stop,
                       const MetaTensor& number,
                       MetaTensor* out) {
  // All three inputs must be 1-D tensors holding exactly one element.
  const auto& start_dims = start.dims();
  PADDLE_ENFORCE_EQ(
      (start_dims.size() == 1) && (start_dims[0] == 1),
      true,
      phi::errors::InvalidArgument("The shape of Input(Start) must be [1],"
                                   "but received input shape is [%s].",
                                   start_dims));
  const auto& stop_dims = stop.dims();
  PADDLE_ENFORCE_EQ(
      (stop_dims.size() == 1) && (stop_dims[0] == 1),
      true,
      phi::errors::InvalidArgument("The shape of Input(Stop) must be [1],"
                                   "but received input shape is [%s].",
                                   stop_dims));
  const auto& num_dims = number.dims();
  PADDLE_ENFORCE_EQ(
      (num_dims.size() == 1) && (num_dims[0] == 1),
      true,
      phi::errors::InvalidArgument("The shape of Input(Num) must be [1],"
                                   "but received input shape is [%s].",
                                   num_dims));
  // The output length is only known once Num's value is read at run time,
  // so the static shape stays dynamic ([-1]); the kernel resizes it later.
  out->set_dims(phi::make_ddim({-1}));
  out->set_dtype(start.dtype());
}
} // namespace phi } // namespace phi
...@@ -58,4 +58,9 @@ void LerpInferMeta(const MetaTensor& x, ...@@ -58,4 +58,9 @@ void LerpInferMeta(const MetaTensor& x,
const MetaTensor& weight, const MetaTensor& weight,
MetaTensor* out); MetaTensor* out);
// Infer-meta for the linspace op: validates that start, stop and number are
// all 1-D single-element tensors, then marks `out` as a 1-D tensor of
// dynamic length ([-1]) with start's dtype (the real length is only known
// once `number` is read at run time).
void LinspaceInferMeta(const MetaTensor& start,
                       const MetaTensor& stop,
                       const MetaTensor& number,
                       MetaTensor* out);
} // namespace phi } // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/linspace_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
namespace phi {
// CPU linspace kernel: writes `number` evenly spaced values covering
// [start, stop], cast to `dtype`, into `out`.
template <typename T, typename Context>
void LinspaceKernel(const Context& ctx,
                    const DenseTensor& start,
                    const DenseTensor& stop,
                    const DenseTensor& number,
                    DataType dtype,
                    DenseTensor* out) {
  // This kernel is registered for CPU, so `number` can be read directly.
  int32_t num = number.data<int32_t>()[0];
  // Cast the scalar endpoints to the requested output dtype.
  auto start_cast = phi::funcs::TransDataType(ctx, start, dtype);
  auto stop_cast = phi::funcs::TransDataType(ctx, stop, dtype);
  T lo = start_cast.template data<T>()[0];
  T hi = stop_cast.template data<T>()[0];
  PADDLE_ENFORCE_GT(
      num,
      0,
      phi::errors::InvalidArgument("The num of linspace op should be larger "
                                   "than 0, but received num is %d",
                                   num));
  out->Resize(phi::make_ddim({num}));
  T* out_data = ctx.template Alloc<T>(out);

  if (num == 1) {
    out_data[0] = static_cast<T>(lo);
  } else {
    // step is kept in double so integral dtypes retain full precision.
    double step = (static_cast<double>(hi - lo)) / (num - 1);
    // Front half runs forward from `lo`, back half backward from `hi`,
    // keeping both endpoints exact.
    int mid = num / 2;
    for (int i = 0; i < num; ++i) {
      out_data[i] = i < mid ? static_cast<T>(lo + step * i)
                            : static_cast<T>(hi - step * (num - i - 1));
    }
  }
}
} // namespace phi
// Register the phi CPU linspace kernel for every dtype the op supports.
PD_REGISTER_KERNEL(linspace,
                   CPU,
                   ALL_LAYOUT,
                   phi::LinspaceKernel,
                   float,
                   int32_t,
                   int64_t,
                   double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/cast_kernel.h"
namespace phi {
namespace funcs {
// Casts tensor `x` to `dtype`, dispatching phi::Cast on the source dtype.
// Returns a new DenseTensor holding the converted data; throws
// Unimplemented when the source dtype is not handled below.
template <typename Context>
phi::DenseTensor TransDataType(const Context& dev_ctx,
                               const phi::DenseTensor& x,
                               DataType dtype) {
  // Fixed log-message typo: "typoe" -> "type".
  VLOG(3) << "TransDataType "
          << "src type:" << x.dtype() << "; dst type: " << dtype;

  switch (x.dtype()) {
    case DataType::FLOAT32:
      return phi::Cast<float>(dev_ctx, x, dtype);
    case DataType::FLOAT64:
      return phi::Cast<double>(dev_ctx, x, dtype);
    case DataType::INT32:
      return phi::Cast<int32_t>(dev_ctx, x, dtype);
    case DataType::INT64:
      return phi::Cast<int64_t>(dev_ctx, x, dtype);
    case DataType::FLOAT16:
      return phi::Cast<phi::dtype::float16>(dev_ctx, x, dtype);
    case DataType::BFLOAT16:
      return phi::Cast<phi::dtype::bfloat16>(dev_ctx, x, dtype);
    case DataType::BOOL:
      return phi::Cast<bool>(dev_ctx, x, dtype);
    case DataType::INT16:
      return phi::Cast<int16_t>(dev_ctx, x, dtype);
    case DataType::UINT8:
      return phi::Cast<uint8_t>(dev_ctx, x, dtype);
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          x.dtype()));
  }
}
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/linspace_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
// Grid-stride kernel writing `size` evenly spaced values into `out`.
// The front half is generated forward from `start` and the back half
// backward from `stop`, keeping both endpoints exact for float types.
template <typename T>
__global__ void LinspaceKernelInner(
    T start, T stop, double step, int64_t size, T* out) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t half = size / 2;
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += stride) {
    out[i] = (i < half) ? static_cast<T>(start + step * i)
                        : static_cast<T>(stop - step * (size - i - 1));
  }
}
// Degenerate case (a single output element): store the one value.
// If launched with multiple threads, every thread writes identical bytes
// to out[0], so the duplicate stores are benign.
template <typename T>
__global__ void LinspaceSpecialKernel(T val, T* out) {
  *out = static_cast<T>(val);
}
// GPU linspace kernel: writes `number` evenly spaced values covering
// [start, stop], cast to `dtype`, into `out`. start/stop are 1-element
// device tensors; number is a 1-element int32 tensor.
template <typename T, typename Context>
void LinspaceKernel(const Context& ctx,
                    const DenseTensor& start,
                    const DenseTensor& stop,
                    const DenseTensor& number,
                    DataType dtype,
                    DenseTensor* out) {
  // Cast the scalar endpoints to the requested output dtype on the device.
  auto start_t = phi::funcs::TransDataType(ctx, start, dtype);
  auto stop_t = phi::funcs::TransDataType(ctx, stop, dtype);

  // Copy the scalar parameters to the host. The copies must be blocking
  // (blocking == true): with a non-blocking copy the host would read
  // n_*.data() below before the device-to-host transfer has completed.
  DenseTensor n_start;
  DenseTensor n_stop;
  DenseTensor n_num;
  phi::Copy(ctx, start_t, phi::CPUPlace(), true, &n_start);
  T start_data = n_start.data<T>()[0];
  phi::Copy(ctx, stop_t, phi::CPUPlace(), true, &n_stop);
  T stop_data = n_stop.data<T>()[0];
  phi::Copy(ctx, number, phi::CPUPlace(), true, &n_num);
  int64_t num = static_cast<int64_t>(n_num.data<int32_t>()[0]);

  PADDLE_ENFORCE_GT(
      num,
      0,
      phi::errors::InvalidArgument("The num of linspace op should be larger "
                                   "than 0, but received num is %d",
                                   num));

  out->Resize(phi::make_ddim({num}));
  T* out_data = ctx.template Alloc<T>(out);

  auto stream = ctx.stream();
  if (num != 1) {
    // step is computed in double so integral dtypes keep full precision.
    double step = (static_cast<double>(stop_data - start_data)) / (num - 1);
    const int block = 512;
    const int grid = (num + block - 1) / block;
    LinspaceKernelInner<T><<<grid, block, 0, stream>>>(
        start_data, stop_data, step, num, out_data);
  } else {
    // A single output element needs only one thread; launching a full
    // 512-thread block would just issue 511 redundant (identical) stores.
    LinspaceSpecialKernel<T><<<1, 1, 0, stream>>>(start_data, out_data);
  }
}
} // namespace phi
// Register the phi GPU linspace kernel for every dtype the op supports.
PD_REGISTER_KERNEL(linspace,
                   GPU,
                   ALL_LAYOUT,
                   phi::LinspaceKernel,
                   float,
                   int32_t,
                   int64_t,
                   double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// Computes `number` evenly spaced values covering [start, stop], cast to
// `dtype`, and writes them into `out`. start/stop/number are expected to
// be 1-element tensors; `number` holds the int32 element count.
template <typename T, typename Context>
void LinspaceKernel(const Context& ctx,
                    const DenseTensor& start,
                    const DenseTensor& stop,
                    const DenseTensor& number,
                    DataType dtype,
                    DenseTensor* out);
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册