未验证 提交 639675de 编写于 作者: 0 0x45f 提交者: GitHub

move eye、size、erfinv、pixel_shuffle OP to phi (#39712)

* move eye OP to pten

* move size OP to pten

* merge develop

* fix merge

* move files

* move erfinv OP to phi

* remove comment

* move pixel_shuffle OP to phi

* remove comment

* fix PT_REGISTER

* fix NPU

* fix CR

* remove size_sig.cc for PR-CI-Coverage
上级 4fe465cb
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/erfinv_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
......@@ -85,16 +85,3 @@ REGISTER_OPERATOR(
paddle::operators::ErfinvInplaceInferer);
REGISTER_OPERATOR(erfinv_grad, paddle::operators::ErfinvGradOp);
REGISTER_OP_CPU_KERNEL(
erfinv,
paddle::operators::ErfinvKernel<paddle::platform::CPUDeviceContext, float>,
paddle::operators::ErfinvKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(
erfinv_grad,
paddle::operators::ErfinvGradKernel<paddle::platform::CPUDeviceContext,
float>,
paddle::operators::ErfinvGradKernel<paddle::platform::CPUDeviceContext,
double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES // use M_2_SQRTPI on Windows
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
// ndtri(x * 0.5 + 0.5) / sqrt(2)
template <typename DeviceContext, typename T>
class ErfinvKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
constexpr T half = static_cast<T>(0.5);
constexpr T half_sqrt = static_cast<T>(M_SQRT1_2);
eigen_out.device(place) = (eigen_in * half + half).ndtri() * half_sqrt;
}
};
// sqrt(pi) / 2 * exp(square(out)) * grad
template <typename DeviceContext, typename T>
class ErfinvGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto out = ctx.Input<framework::Tensor>("Out");
auto dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_dout = framework::EigenVector<T>::Flatten(*dout);
auto eigen_dx = framework::EigenVector<T>::Flatten(*dx);
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
constexpr T half_sqrt_pi = static_cast<T>(1 / M_2_SQRTPI);
eigen_dx.device(place) =
half_sqrt_pi * eigen_dout * eigen_out.square().exp();
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eye_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
......@@ -82,14 +82,8 @@ Return an identity tensor whose shape is [num_rows, num_columns].
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(
eye, ops::EyeOp, ops::EyeOpMaker, ops::EyeOpVarTypeInference,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(eye, ops::EyeKernel<CPU, float>,
ops::EyeKernel<CPU, double>,
ops::EyeKernel<CPU, int64_t>, ops::EyeKernel<CPU, int>,
ops::EyeKernel<CPU, paddle::platform::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eye_op.h"
namespace ops = paddle::operators;
namespace plf = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
eye, ops::EyeKernel<plf::CUDADeviceContext, float>,
ops::EyeKernel<plf::CUDADeviceContext, double>,
ops::EyeKernel<plf::CUDADeviceContext, int64_t>,
ops::EyeKernel<plf::CUDADeviceContext, int>,
ops::EyeKernel<plf::CUDADeviceContext, paddle::platform::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
template <typename T>
struct EyeFunctor {
EyeFunctor(int64_t num_columns, T* output)
: num_columns_(num_columns), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const {
output_[idx * num_columns_ + idx] = static_cast<T>(1);
}
int64_t num_columns_;
T* output_;
};
template <typename DeviceContext, typename T>
class EyeKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
auto num_rows = ctx.Attr<int64_t>("num_rows");
auto num_columns = ctx.Attr<int64_t>("num_columns");
if (num_columns == -1) num_columns = num_rows;
auto* out_tensor = ctx.Output<framework::Tensor>("Out");
T* out_data = out_tensor->mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
set_zero(dev_ctx, out_tensor, static_cast<T>(0));
int64_t num_eyes = (std::min)(num_rows, num_columns);
platform::ForRange<DeviceContext> for_range(dev_ctx, num_eyes);
EyeFunctor<T> functor(num_columns, out_data);
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eye_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
......@@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/pixel_shuffle_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
......@@ -177,16 +177,6 @@ REGISTER_OPERATOR(pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker,
REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp);
REGISTER_OP_CPU_KERNEL(
pixel_shuffle,
ops::PixelShuffleOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::PixelShuffleOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
pixel_shuffle_grad,
ops::PixelShuffleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::PixelShuffleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_VERSION(pixel_shuffle)
.AddCheckpoint(
R"ROC(
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/pixel_shuffle_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
pixel_shuffle, ops::PixelShuffleOpKernel<plat::CUDADeviceContext, float>,
ops::PixelShuffleOpKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
pixel_shuffle_grad,
ops::PixelShuffleGradOpKernel<plat::CUDADeviceContext, float>,
ops::PixelShuffleGradOpKernel<plat::CUDADeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class PixelShuffleOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
int factor = ctx.Attr<int>("upscale_factor");
std::string data_format = ctx.Attr<std::string>("data_format");
bool channel_last = (data_format == "NHWC");
auto in_dims = in->dims();
auto o_dims = out->dims();
framework::Tensor t;
t.ShareDataWith(*in);
if (!channel_last) {
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});
} else {
t.Resize({in_dims[0], in_dims[1], in_dims[2], o_dims[3], factor, factor});
}
std::vector<int> axis = {0, 1, 4, 2, 5, 3};
framework::Tensor o;
o.ShareDataWith(*out);
if (!channel_last) {
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});
} else {
o.Resize({in_dims[0], in_dims[1], factor, in_dims[2], factor, o_dims[3]});
}
phi::funcs::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
out->Resize(o_dims);
}
};
template <typename DeviceContext, typename T>
class PixelShuffleGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
int factor = ctx.Attr<int>("upscale_factor");
std::string data_format = ctx.Attr<std::string>("data_format");
bool channel_last = (data_format == "NHWC");
auto do_dims = dout->dims();
auto dx_dims = dx->dims();
framework::Tensor t;
t.ShareDataWith(*dout);
if (!channel_last) {
t.Resize(
{do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});
} else {
t.Resize(
{do_dims[0], dx_dims[1], factor, dx_dims[2], factor, do_dims[3]});
}
std::vector<int> axis = {0, 1, 3, 5, 2, 4};
framework::Tensor o;
o.ShareDataWith(*dx);
if (!channel_last) {
o.Resize(
{do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});
} else {
o.Resize(
{do_dims[0], dx_dims[1], dx_dims[2], do_dims[3], factor, factor});
}
phi::funcs::Transpose<DeviceContext, T, 6> trans;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
trans(dev_ctx, t, &o, axis);
dx->Resize(dx_dims);
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/size_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......@@ -53,7 +52,3 @@ REGISTER_OPERATOR(
size, ops::SizeOp, ops::SizeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(size, ops::SizeKernel<int>, ops::SizeKernel<int64_t>,
ops::SizeKernel<paddle::platform::float16>,
ops::SizeKernel<float>, ops::SizeKernel<double>,
ops::SizeKernel<bool>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/size_op.h"
REGISTER_OP_CUDA_KERNEL(
size, paddle::operators::SizeKernel<int>,
paddle::operators::SizeKernel<int64_t>,
paddle::operators::SizeKernel<paddle::platform::float16>,
paddle::operators::SizeKernel<float>, paddle::operators::SizeKernel<bool>,
paddle::operators::SizeKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class SizeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("Input");
auto* out_t = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
auto out_data = out_t->mutable_data<int64_t>(place);
auto cpu_place = platform::CPUPlace();
if (place == cpu_place) {
out_data[0] = in_t->numel();
} else {
Tensor cpu_tensor;
auto cpu_data =
cpu_tensor.mutable_data<int64_t>(out_t->dims(), cpu_place);
cpu_data[0] = in_t->numel();
paddle::framework::TensorCopy(cpu_tensor, place, out_t);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/mul_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/erfinv_grad_kernel.h"
#include "paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
erfinv_grad, CPU, ALL_LAYOUT, phi::ErfinvGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/erfinv_kernel.h"
#include "paddle/phi/kernels/impl/erfinv_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(erfinv, CPU, ALL_LAYOUT, phi::ErfinvKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/eye_kernel.h"
#include "paddle/phi/kernels/impl/eye_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(eye,
CPU,
ALL_LAYOUT,
phi::EyeKernel,
float,
double,
int64_t,
int,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/pixel_shuffle_grad_kernel.h"
#include "paddle/phi/kernels/impl/pixel_shuffle_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(pixel_shuffle_grad,
CPU,
ALL_LAYOUT,
phi::PixelShuffleGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/pixel_shuffle_kernel.h"
#include "paddle/phi/kernels/impl/pixel_shuffle_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
pixel_shuffle, CPU, ALL_LAYOUT, phi::PixelShuffleKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/size_kernel.h"
#include "paddle/phi/kernels/impl/size_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(size,
CPU,
ALL_LAYOUT,
phi::SizeKernel,
int,
int64_t,
phi::dtype::float16,
float,
double,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ErfinvGradKernel(const Context& ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void EyeKernel(const Context& ctx,
int64_t num_rows,
int64_t num_columns,
int dtype,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/erfinv_grad_kernel.h"
#include "paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
erfinv_grad, GPU, ALL_LAYOUT, phi::ErfinvGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/erfinv_kernel.h"
#include "paddle/phi/kernels/impl/erfinv_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(erfinv, GPU, ALL_LAYOUT, phi::ErfinvKernel, float, double) {}
......@@ -12,17 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/erfinv_op.h"
#include "paddle/phi/kernels/eye_kernel.h"
#include "paddle/phi/kernels/impl/eye_kernel_impl.h"
REGISTER_OP_CUDA_KERNEL(
erfinv,
paddle::operators::ErfinvKernel<paddle::platform::CUDADeviceContext, float>,
paddle::operators::ErfinvKernel<paddle::platform::CUDADeviceContext,
double>);
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
REGISTER_OP_CUDA_KERNEL(
erfinv_grad,
paddle::operators::ErfinvGradKernel<paddle::platform::CUDADeviceContext,
float>,
paddle::operators::ErfinvGradKernel<paddle::platform::CUDADeviceContext,
double>);
PD_REGISTER_KERNEL(eye,
GPU,
ALL_LAYOUT,
phi::EyeKernel,
float,
double,
int64_t,
int,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/pixel_shuffle_grad_kernel_impl.h"
#include "paddle/phi/kernels/pixel_shuffle_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(pixel_shuffle_grad,
GPU,
ALL_LAYOUT,
phi::PixelShuffleGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/pixel_shuffle_kernel_impl.h"
#include "paddle/phi/kernels/pixel_shuffle_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
pixel_shuffle, GPU, ALL_LAYOUT, phi::PixelShuffleKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/size_kernel_impl.h"
#include "paddle/phi/kernels/size_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(size,
GPU,
ALL_LAYOUT,
phi::SizeKernel,
int,
int64_t,
phi::dtype::float16,
float,
double,
bool) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES // use M_2_SQRTPI on Windows
#endif
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void ErfinvGradKernel(const Context& ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
ctx.template Alloc<T>(x_grad);
auto eigen_out = EigenVector<T>::Flatten(out);
auto eigen_dout = EigenVector<T>::Flatten(out_grad);
auto eigen_dx = EigenVector<T>::Flatten(*x_grad);
auto& place = *ctx.eigen_device();
constexpr T half_sqrt_pi = static_cast<T>(1 / M_2_SQRTPI);
eigen_dx.device(place) = half_sqrt_pi * eigen_dout * eigen_out.square().exp();
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES // use M_2_SQRTPI on Windows
#endif
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<T>(out);
auto eigen_in = EigenVector<T>::Flatten(x);
auto eigen_out = EigenVector<T>::Flatten(*out);
auto& place = *ctx.eigen_device();
constexpr T half = static_cast<T>(0.5);
constexpr T half_sqrt = static_cast<T>(M_SQRT1_2);
eigen_out.device(place) = (eigen_in * half + half).ndtri() * half_sqrt;
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T>
struct EyeFunctor {
EyeFunctor(int64_t num_columns, T* output)
: num_columns_(num_columns), output_(output) {}
HOSTDEVICE void operator()(size_t idx) const {
output_[idx * num_columns_ + idx] = static_cast<T>(1);
}
int64_t num_columns_;
T* output_;
};
template <typename T, typename Context>
void EyeKernel(const Context& ctx,
int64_t num_rows,
int64_t num_columns,
int dtype,
DenseTensor* out) {
auto num = num_columns;
if (num == -1) {
num = num_rows;
}
T* out_data = ctx.template Alloc<T>(out);
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(ctx, out, static_cast<T>(0));
int64_t num_eyes = (std::min)(num_rows, num);
paddle::platform::ForRange<Context> for_range(ctx, num_eyes);
EyeFunctor<T> functor(num, out_data);
for_range(functor);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void PixelShuffleGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int upscale_factor,
const std::string& data_format,
DenseTensor* x_grad) {
auto* dout = &out_grad;
auto* dx = x_grad;
ctx.template Alloc<T>(dx);
int factor = upscale_factor;
bool channel_last = (data_format == "NHWC");
auto do_dims = dout->dims();
auto dx_dims = dx->dims();
DenseTensor t(*dout);
if (!channel_last) {
t.Resize({do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});
} else {
t.Resize({do_dims[0], dx_dims[1], factor, dx_dims[2], factor, do_dims[3]});
}
std::vector<int> axis = {0, 1, 3, 5, 2, 4};
DenseTensor o(*dx);
if (!channel_last) {
o.Resize({do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});
} else {
o.Resize({do_dims[0], dx_dims[1], dx_dims[2], do_dims[3], factor, factor});
}
phi::funcs::Transpose<Context, T, 6> trans;
trans(ctx, t, &o, axis);
dx->Resize(dx_dims);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void PixelShuffleKernel(const Context& ctx,
const DenseTensor& x,
int upscale_factor,
const std::string& data_format,
DenseTensor* out) {
auto* in = &x;
ctx.template Alloc<T>(out);
int factor = upscale_factor;
bool channel_last = (data_format == "NHWC");
auto in_dims = in->dims();
auto o_dims = out->dims();
DenseTensor t(*in);
if (!channel_last) {
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});
} else {
t.Resize({in_dims[0], in_dims[1], in_dims[2], o_dims[3], factor, factor});
}
std::vector<int> axis = {0, 1, 4, 2, 5, 3};
DenseTensor o(*out);
if (!channel_last) {
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});
} else {
o.Resize({in_dims[0], in_dims[1], factor, in_dims[2], factor, o_dims[3]});
}
phi::funcs::Transpose<Context, T, 6> trans;
trans(ctx, t, &o, axis);
out->Resize(o_dims);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/copy_kernel.h"
namespace phi {
template <typename T, typename Context>
void SizeKernel(const Context& ctx,
const DenseTensor& input,
DenseTensor* out) {
auto place = ctx.GetPlace();
auto out_data = ctx.template Alloc<int64_t>(out);
auto cpu_place = phi::CPUPlace();
if (place == cpu_place) {
out_data[0] = input.numel();
} else {
DenseTensor cpu_tensor;
cpu_tensor.Resize(out->dims());
auto cpu_data = ctx.template HostAlloc<int64_t>(&cpu_tensor);
cpu_data[0] = input.numel();
phi::Copy(ctx, cpu_tensor, false, out);
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void PixelShuffleGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int upscale_factor,
const std::string& data_format,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void PixelShuffleKernel(const Context& ctx,
const DenseTensor& x,
int upscale_factor,
const std::string& data_format,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SizeKernel(const Context& ctx, const DenseTensor& input, DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature ErfinvGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"erfinv_grad", {"Out", GradVarName("Out")}, {}, {GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(erfinv_grad, phi::ErfinvGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EyeOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"eye", {}, {"num_rows", "num_columns", "dtype"}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(eye, phi::EyeOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature PixelShuffleOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"pixel_shuffle", {"X"}, {"upscale_factor", "data_format"}, {"Out"});
}
KernelSignature PixelShuffleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("pixel_shuffle_grad",
{GradVarName("Out")},
{"upscale_factor", "data_format"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(pixel_shuffle, phi::PixelShuffleOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pixel_shuffle_grad,
phi::PixelShuffleGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册