未验证 提交 068f48d8 编写于 作者: X xiongkun 提交者: GitHub

[ Phi Kernel ] Transfer as_real to phi. (#44263)

* transfer as_real to phi

* fix erros

* blocking: True -> False
上级 13d01e6e
......@@ -20,7 +20,9 @@
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -94,17 +96,6 @@ class AsRealOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "as_real");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "as_real");
auto out_dims_v = phi::vectorize(ctx->GetInputDim("X"));
out_dims_v.push_back(2);
const framework::DDim out_dims = phi::make_ddim(out_dims_v);
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -148,6 +139,9 @@ class AsRealGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(as_real,
AsRealInferShapeFunctor,
PD_INFER_META(phi::AsRealInferMeta));
REGISTER_OPERATOR(as_complex,
ops::AsComplexOp,
......@@ -158,13 +152,10 @@ REGISTER_OPERATOR(as_complex,
REGISTER_OPERATOR(as_real,
ops::AsRealOp,
ops::AsRealOpMaker,
AsRealInferShapeFunctor,
ops::AsRealGradMaker<paddle::framework::OpDesc>,
ops::AsRealGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(as_complex,
ops::AsComplexKernel<phi::CPUContext, float>,
ops::AsComplexKernel<phi::CPUContext, double>);
REGISTER_OP_CPU_KERNEL(as_real,
ops::AsRealKernel<phi::CPUContext, float>,
ops::AsRealKernel<phi::CPUContext, double>);
......@@ -22,8 +22,3 @@ REGISTER_OP_CUDA_KERNEL(
as_complex,
ops::AsComplexKernel<paddle::platform::CUDADeviceContext, float>,
ops::AsComplexKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
as_real,
ops::AsRealKernel<paddle::platform::CUDADeviceContext, float>,
ops::AsRealKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -41,20 +41,5 @@ class AsComplexKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class AsRealKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const auto* x = context.Input<framework::LoDTensor>("X");
auto* out = context.Output<framework::LoDTensor>("Out");
out->mutable_data<T>(context.GetPlace());
const framework::DDim out_dims_original = out->dims();
framework::TensorCopy(*x, context.GetPlace(), out);
out->Resize(out_dims_original); // restored the shape
out->mutable_data<T>(context.GetPlace()); // restore the dtype
}
};
} // namespace operators
} // namespace paddle
......@@ -167,6 +167,15 @@
func : argsort
backward : argsort_grad
- api : as_real
args : (Tensor x)
output : Tensor
infer_meta :
func : AsRealInferMeta
kernel :
func : as_real
# backward : as_complex
# asin
- api : asin
args : (Tensor x)
......
......@@ -148,6 +148,14 @@ void ArgsortInferMeta(const MetaTensor& input,
indices->share_lod(input);
}
void AsRealInferMeta(const MetaTensor& input, MetaTensor* output) {
auto out_dims_v = phi::vectorize(input.dims());
out_dims_v.push_back(2);
auto out_dims = phi::make_ddim(out_dims_v);
output->set_dims(out_dims);
output->share_lod(input);
}
void BatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape,
int x_batch_size_dim,
......
......@@ -48,6 +48,8 @@ void ArgsortInferMeta(const MetaTensor& input,
MetaTensor* output,
MetaTensor* indices);
void AsRealInferMeta(const MetaTensor& input, MetaTensor* output);
void BatchSizeLikeInferMeta(const MetaTensor& x,
const std::vector<int>& shape,
int x_batch_size_dim,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void AsRealKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/as_real_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/as_real_impl.h"
PD_REGISTER_KERNEL(as_real, CPU, ALL_LAYOUT, phi::AsRealKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/as_real_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/as_real_impl.h"
PD_REGISTER_KERNEL(as_real, GPU, ALL_LAYOUT, phi::AsRealKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/as_real_kernel.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
template <typename T, typename Context>
void AsRealKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<T>(out);
auto out_dims_original = out->dims();
Copy(ctx, x, ctx.GetPlace(), false, out);
out->Resize(out_dims_original); // restored the shape.
out->set_type(
paddle::experimental::CppTypeToDataType<T>::Type()); // restored the
// dtype.
}
} // namespace phi
......@@ -67,6 +67,7 @@ class TestViewAsRealOp(OpTest):
out_ref = ref_view_as_real(x)
self.inputs = {'X': x}
self.outputs = {'Out': out_ref}
self.python_api = paddle.as_real
self.out_grad = np.ones([10, 10, 2], dtype="float64")
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册