未验证 提交 31729a62 编写于 作者: L Liu-xiandong 提交者: GitHub

[phi] modify the shape OP and move inferMeta of shape,matrix_pow,multi_dot (#40506)

* [phi] move matrix_power op

* MatrixInverse fluid -> phi

* modify the CMake to fix compile bug

* delete useless comment

* mutable memory -> phi Alloc

* modify the include file

* modify the include file

* fix bug in CI compiler

* [phi]modify the shape OP and move inferMeta of shape,matrix_pow,multi_dot

* delete useless comment

* fix bug in CI

* modify after review
上级 9bdee437
...@@ -14,8 +14,11 @@ ...@@ -14,8 +14,11 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -23,26 +26,6 @@ namespace operators { ...@@ -23,26 +26,6 @@ namespace operators {
class MatrixPowerOp : public framework::OperatorWithKernel { class MatrixPowerOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "matrix_power");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "matrix_power");
auto dims = ctx->GetInputDim("X");
auto n_dim = dims.size();
PADDLE_ENFORCE_GE(n_dim, 2,
platform::errors::InvalidArgument(
"The Input(X) should have at least 2 dimensions. But "
"received a %d dimension tensor.",
n_dim));
PADDLE_ENFORCE_EQ(dims[n_dim - 2], dims[n_dim - 1],
platform::errors::InvalidArgument(
"The inner-most 2 dimensions of Input(X) all should "
"be square matrices "
"But received X's shape[-2] = %d and shape[-1] = %d.",
dims[n_dim - 2], dims[n_dim - 1]));
ctx->SetOutputDim("Out", dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}; };
class MatrixPowerOpMaker : public framework::OpProtoAndCheckerMaker { class MatrixPowerOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -116,9 +99,14 @@ class MatrixPowerGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -116,9 +99,14 @@ class MatrixPowerGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(matrix_power, MatrixPowerInferShapeFunctor,
PD_INFER_META(phi::MatrixPowerInferMeta));
REGISTER_OPERATOR(matrix_power, ops::MatrixPowerOp, ops::MatrixPowerOpMaker, REGISTER_OPERATOR(matrix_power, ops::MatrixPowerOp, ops::MatrixPowerOpMaker,
ops::MatrixPowerOpInferVarType, ops::MatrixPowerOpInferVarType,
ops::MatrixPowerGradOpMaker<paddle::framework::OpDesc>, ops::MatrixPowerGradOpMaker<paddle::framework::OpDesc>,
ops::MatrixPowerGradOpMaker<paddle::imperative::OpBase>); ops::MatrixPowerGradOpMaker<paddle::imperative::OpBase>,
MatrixPowerInferShapeFunctor);
REGISTER_OPERATOR(matrix_power_grad, ops::MatrixPowerGradOp); REGISTER_OPERATOR(matrix_power_grad, ops::MatrixPowerGradOp);
...@@ -16,77 +16,19 @@ limitations under the License. */ ...@@ -16,77 +16,19 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/operators/utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
/**
* @brief compute the output shape and check the input shape valid or not
*/
inline framework::DDim ComputeAndCheckShape(
const bool is_runtime, const std::vector<framework::DDim>& inputs_dims) {
const size_t n = inputs_dims.size();
auto first_dim = inputs_dims[0];
bool is_vector = false;
framework::DDim out_dim;
PADDLE_ENFORCE_LT(
first_dim.size(), static_cast<size_t>(3),
platform::errors::InvalidArgument(
"multi_dot: the first input tensor must be 1D or 2D but got[%d]!",
static_cast<int>(first_dim.size())));
// If the first tensor is 1D of size n view it as a row vector (1, n)
if (first_dim.size() == 1) {
first_dim = phi::make_ddim({1, static_cast<int>(first_dim[0])});
is_vector = true;
}
auto last_dim = inputs_dims[n - 1];
PADDLE_ENFORCE_LT(
last_dim.size(), static_cast<size_t>(3),
platform::errors::InvalidArgument(
"the last input tensor of multi_dot must be 1D or 2D but got[%d]!",
static_cast<int>(first_dim.size())));
// If the last tensor is 1D of size n view it as a column vector (n, 1)
if (last_dim.size() == 1) {
last_dim = phi::make_ddim({static_cast<int>(last_dim[0]), 1});
out_dim = is_vector ? phi::make_ddim({1}) : phi::make_ddim({first_dim[0]});
} else {
out_dim = is_vector ? phi::make_ddim({last_dim[1]})
: phi::make_ddim({first_dim[0], last_dim[1]});
}
auto width = first_dim[1];
for (size_t i = 1; i < n - 1; i++) {
PADDLE_ENFORCE_EQ(inputs_dims[i].size(), static_cast<size_t>(2),
platform::errors::InvalidArgument(
"the input tensor of multi_dot op must be 2D."));
const auto& tmp_dim = inputs_dims[i];
PADDLE_ENFORCE_EQ(
tmp_dim[0], width,
platform::errors::InvalidArgument(
"the input matrix does not meet the multiplication requirements."));
width = tmp_dim[1];
}
PADDLE_ENFORCE_EQ(
last_dim[0], width,
platform::errors::InvalidArgument(
"the input matrix does not meet the multiplication requirements."));
return out_dim;
}
class MultiDotOpMaker : public framework::OpProtoAndCheckerMaker { class MultiDotOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
...@@ -105,22 +47,6 @@ If the first argument is 1-D it is treated as a row vector. If the last argument ...@@ -105,22 +47,6 @@ If the first argument is 1-D it is treated as a row vector. If the last argument
class MultiDotOp : public framework::OperatorWithKernel { class MultiDotOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "multi_dot");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "multi_dot");
auto inputs_dims = ctx->GetInputsDim("X");
const size_t inputs_num = inputs_dims.size();
PADDLE_ENFORCE_GT(
inputs_num, static_cast<size_t>(1),
platform::errors::InvalidArgument(
"The number of input tensors in multi_dot op should > 1."));
auto out_dims = ComputeAndCheckShape(ctx->IsRuntime(), inputs_dims);
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", "Out");
}
}; };
class MultiDotOpGrad : public framework::OperatorWithKernel { class MultiDotOpGrad : public framework::OperatorWithKernel {
...@@ -171,9 +97,15 @@ class MultiDotOpDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -171,9 +97,15 @@ class MultiDotOpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(multi_dot, MultiDotInferShapeFunctor,
PD_INFER_META(phi::MultiDotInferMeta));
REGISTER_OPERATOR(multi_dot, ops::MultiDotOp, ops::MultiDotOpMaker, REGISTER_OPERATOR(multi_dot, ops::MultiDotOp, ops::MultiDotOpMaker,
ops::MultiDotOpGradMaker<paddle::framework::OpDesc>, ops::MultiDotOpGradMaker<paddle::framework::OpDesc>,
ops::MultiDotOpGradMaker<paddle::imperative::OpBase>); ops::MultiDotOpGradMaker<paddle::imperative::OpBase>,
MultiDotInferShapeFunctor);
REGISTER_OPERATOR(multi_dot_grad, ops::MultiDotOpGrad, REGISTER_OPERATOR(multi_dot_grad, ops::MultiDotOpGrad,
ops::MultiDotOpDoubleGradMaker<paddle::framework::OpDesc>, ops::MultiDotOpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MultiDotOpDoubleGradMaker<paddle::imperative::OpBase>); ops::MultiDotOpDoubleGradMaker<paddle::imperative::OpBase>);
...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -22,17 +25,6 @@ class ShapeOp : public framework::OperatorWithKernel { ...@@ -22,17 +25,6 @@ class ShapeOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
platform::errors::InvalidArgument(
"Input (Input) of get_shape op should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output (Out) of get_shape op should not be null."));
auto in_dim = ctx->GetInputDim("Input");
ctx->SetOutputDim("Out", {in_dim.size()});
}
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
...@@ -89,7 +81,12 @@ Return the shape of the input. ...@@ -89,7 +81,12 @@ Return the shape of the input.
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(shape, ShapeInferShapeFunctor,
PD_INFER_META(phi::ShapeInferMeta));
REGISTER_OPERATOR( REGISTER_OPERATOR(
shape, ops::ShapeOp, ops::ShapeOpMaker, shape, ops::ShapeOp, ops::ShapeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ShapeInferShapeFunctor);
...@@ -369,6 +369,79 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x, ...@@ -369,6 +369,79 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
out->share_lod(*x.at(0)); out->share_lod(*x.at(0));
} }
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
auto inputs_dims = GetMetaTensorsDim(x);
const size_t inputs_num = inputs_dims.size();
PADDLE_ENFORCE_GT(
inputs_num,
static_cast<size_t>(1),
phi::errors::InvalidArgument(
"The number of input tensors in multi_dot op should > 1."));
const size_t n = inputs_dims.size();
auto first_dim = inputs_dims[0];
bool is_vector = false;
phi::DDim out_dim;
PADDLE_ENFORCE_LT(
first_dim.size(),
static_cast<size_t>(3),
phi::errors::InvalidArgument(
"multi_dot: the first input tensor must be 1D or 2D but got[%d]!",
static_cast<int>(first_dim.size())));
// If the first tensor is 1D of size n view it as a row vector (1, n)
if (first_dim.size() == 1) {
first_dim = phi::make_ddim({1, static_cast<int>(first_dim[0])});
is_vector = true;
}
auto last_dim = inputs_dims[n - 1];
PADDLE_ENFORCE_LT(
last_dim.size(),
static_cast<size_t>(3),
phi::errors::InvalidArgument(
"the last input tensor of multi_dot must be 1D or 2D but got[%d]!",
static_cast<int>(first_dim.size())));
// If the last tensor is 1D of size n view it as a column vector (n, 1)
if (last_dim.size() == 1) {
last_dim = phi::make_ddim({static_cast<int>(last_dim[0]), 1});
out_dim = is_vector ? phi::make_ddim({1}) : phi::make_ddim({first_dim[0]});
} else {
out_dim = is_vector ? phi::make_ddim({last_dim[1]})
: phi::make_ddim({first_dim[0], last_dim[1]});
}
auto width = first_dim[1];
for (size_t i = 1; i < n - 1; i++) {
PADDLE_ENFORCE_EQ(inputs_dims[i].size(),
static_cast<size_t>(2),
phi::errors::InvalidArgument(
"the input tensor of multi_dot op must be 2D."));
const auto& tmp_dim = inputs_dims[i];
PADDLE_ENFORCE_EQ(
tmp_dim[0],
width,
phi::errors::InvalidArgument(
"the input matrix does not meet the multiplication requirements."));
width = tmp_dim[1];
}
PADDLE_ENFORCE_EQ(
last_dim[0],
width,
phi::errors::InvalidArgument(
"the input matrix does not meet the multiplication requirements."));
out->set_dims(out_dim);
out->set_dtype(x.at(0)->dtype());
out->share_lod(*x.at(0));
}
void PsroiPoolInferMeta(const MetaTensor& x, void PsroiPoolInferMeta(const MetaTensor& x,
const MetaTensor& rois, const MetaTensor& rois,
paddle::optional<const MetaTensor&> rois_num, paddle::optional<const MetaTensor&> rois_num,
......
...@@ -70,6 +70,8 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x, ...@@ -70,6 +70,8 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);
void PsroiPoolInferMeta(const MetaTensor& x, void PsroiPoolInferMeta(const MetaTensor& x,
const MetaTensor& rois, const MetaTensor& rois,
paddle::optional<const MetaTensor&> rois_num, paddle::optional<const MetaTensor&> rois_num,
......
...@@ -554,6 +554,28 @@ void IsfiniteInferMeta(const MetaTensor& x, MetaTensor* out) { ...@@ -554,6 +554,28 @@ void IsfiniteInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dtype(DataType::BOOL); out->set_dtype(DataType::BOOL);
} }
void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
auto dims = x.dims();
auto n_dim = dims.size();
PADDLE_ENFORCE_GE(n_dim,
2,
phi::errors::InvalidArgument(
"The Input(X) should have at least 2 dimensions. But "
"received a %d dimension tensor.",
n_dim));
PADDLE_ENFORCE_EQ(dims[n_dim - 2],
dims[n_dim - 1],
phi::errors::InvalidArgument(
"The inner-most 2 dimensions of Input(X) all should "
"be square matrices "
"But received X's shape[-2] = %d and shape[-1] = %d.",
dims[n_dim - 2],
dims[n_dim - 1]));
out->set_dims(dims);
out->share_lod(x);
out->set_dtype(x.dtype());
}
void MaxPoolWithIndexInferMeta(const MetaTensor& x, void MaxPoolWithIndexInferMeta(const MetaTensor& x,
const std::vector<int>& kernel_size, const std::vector<int>& kernel_size,
const std::vector<int>& strides, const std::vector<int>& strides,
...@@ -994,6 +1016,12 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, ...@@ -994,6 +1016,12 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
ReshapeInferMeta(x, shape, out, config); ReshapeInferMeta(x, shape, out, config);
} }
void ShapeInferMeta(const MetaTensor& input, MetaTensor* out) {
auto in_dim = input.dims();
out->set_dims(phi::make_ddim({in_dim.size()}));
out->set_dtype(DataType::INT32);
}
void ShardIndexInferMeta(const MetaTensor& in, void ShardIndexInferMeta(const MetaTensor& in,
int index_num, int index_num,
int nshards, int nshards,
......
...@@ -98,6 +98,8 @@ void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out); ...@@ -98,6 +98,8 @@ void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out);
void IsfiniteInferMeta(const MetaTensor& input, MetaTensor* out); void IsfiniteInferMeta(const MetaTensor& input, MetaTensor* out);
void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);
void MaxPoolWithIndexInferMeta(const MetaTensor& x, void MaxPoolWithIndexInferMeta(const MetaTensor& x,
const std::vector<int>& kernel_size, const std::vector<int>& kernel_size,
const std::vector<int>& strides, const std::vector<int>& strides,
...@@ -162,6 +164,8 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, ...@@ -162,6 +164,8 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
MetaTensor* out, MetaTensor* out,
MetaConfig config = MetaConfig()); MetaConfig config = MetaConfig());
void ShapeInferMeta(const MetaTensor& input, MetaTensor* out);
void ShardIndexInferMeta(const MetaTensor& in, void ShardIndexInferMeta(const MetaTensor& in,
int index_num, int index_num,
int nshards, int nshards,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/shape_kernel.h"
#include "paddle/phi/kernels/impl/shape_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(shape,
CPU,
ALL_LAYOUT,
phi::ShapeKernel,
bool,
int,
int8_t,
uint8_t,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ShapeKernel(const Context& ctx,
const DenseTensor& input,
DenseTensor* out) {
auto in_var = &input;
phi::DDim in_dims;
in_dims = in_var->dims();
auto out_t = out;
out_t->Resize({in_dims.size()});
auto out_data = ctx.template HostAlloc<int32_t>(out_t);
for (int i = 0; i < in_dims.size(); ++i) {
out_data[i] = in_dims[i];
}
}
} // namespace phi
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/shape_kernel.h"
namespace phi { namespace phi {
namespace sr { namespace sr {
...@@ -25,15 +26,7 @@ template <typename T, typename Context> ...@@ -25,15 +26,7 @@ template <typename T, typename Context>
void ShapeKernel(const Context& ctx, void ShapeKernel(const Context& ctx,
const SelectedRows& input, const SelectedRows& input,
DenseTensor* out) { DenseTensor* out) {
auto in_var = input; phi::ShapeKernel<T, Context>(ctx, input.value(), out);
phi::DDim in_dims;
in_dims = in_var.value().dims();
auto out_t = out;
out_t->Resize({in_dims.size()});
auto out_data = ctx.template HostAlloc<int32_t>(out_t);
for (int i = 0; i < in_dims.size(); ++i) {
out_data[i] = in_dims[i];
}
} }
} // namespace sr } // namespace sr
......
...@@ -13,12 +13,43 @@ See the License for the specific language governing permissions and ...@@ -13,12 +13,43 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/phi/kernels/shape_kernel.h" #include "paddle/phi/kernels/shape_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/shape_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void ShapeKernel(const Context& ctx,
const DenseTensor& input,
DenseTensor* out) {
auto in_var = &input;
phi::DDim in_dims;
in_dims = in_var->dims();
auto out_t = out;
out_t->Resize({in_dims.size()});
auto out_data = ctx.template HostAlloc<int32_t>(out_t);
for (int i = 0; i < in_dims.size(); ++i) {
out_data[i] = in_dims[i];
}
}
} // namespace phi
PD_REGISTER_KERNEL(shape,
CPU,
ALL_LAYOUT,
phi::ShapeKernel,
bool,
int,
int8_t,
uint8_t,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(shape, PD_REGISTER_KERNEL(shape,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -33,3 +64,4 @@ PD_REGISTER_KERNEL(shape, ...@@ -33,3 +64,4 @@ PD_REGISTER_KERNEL(shape,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>, phi::dtype::complex<double>,
phi::dtype::float16) {} phi::dtype::float16) {}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册