未验证 提交 fb473067 编写于 作者: F From00 提交者: GitHub

Move Abs OP to pten (#39492)

* Move Abs op to pten

* Fix NPU compilation error

* Fix CI error

* Use LaunchSameDimsElementwiseCudaKernel in pten
上级 536a55fa
......@@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/abs_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
......@@ -108,7 +107,7 @@ class AbsDoubleGradMaker : public framework::SingleGradOpMaker<T> {
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("abs_grad_grad");
op->SetType("abs_double_grad");
// input1: x
op->SetInput("X", this->Input("X"));
// input2: ddx
......@@ -159,37 +158,4 @@ REGISTER_OPERATOR(abs_grad, ops::AbsGradOp,
ops::AbsDoubleGradMaker<paddle::framework::OpDesc>,
ops::AbsDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(abs_grad_grad, ops::AbsDoubleGradOp);
REGISTER_OP_CPU_KERNEL(
abs, ops::AbsKernel<paddle::platform::CPUDeviceContext, float>,
ops::AbsKernel<paddle::platform::CPUDeviceContext, double>,
ops::AbsKernel<paddle::platform::CPUDeviceContext, int>,
ops::AbsKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::AbsKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::AbsKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
abs_grad, ops::AbsGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
abs_grad_grad,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OPERATOR(abs_double_grad, ops::AbsDoubleGradOp);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/abs_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T, typename Enable = void>
struct CudaAbsFunctor;
template <typename T>
struct CudaAbsFunctor<T, math::Complex<T, math::Real<T>>> {
__device__ __forceinline__ math::Real<T> operator()(const T x) const {
return abs(x);
}
};
template <typename T>
struct CudaAbsFunctor<T, math::NoComplex<T, math::Real<T>>> {
__device__ __forceinline__ T operator()(const T x) const {
return std::abs(x);
}
};
template <typename T>
class AbsKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
out->mutable_data<math::Real<T>>(context.GetPlace());
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
std::vector<const framework::Tensor*> ins = {x};
std::vector<framework::Tensor*> outs = {out};
auto functor = CudaAbsFunctor<T>();
paddle::operators::LaunchSameDimsElementwiseCudaKernel<math::Real<T>>(
dev_ctx, ins, &outs, functor);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
abs, ops::AbsKernel<plat::CUDADeviceContext, float>,
ops::AbsKernel<plat::CUDADeviceContext, double>,
ops::AbsKernel<plat::CUDADeviceContext, int>,
ops::AbsKernel<plat::CUDADeviceContext, int64_t>,
ops::AbsKernel<plat::CUDADeviceContext, plat::float16>,
ops::AbsKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::AbsKernel<plat::CUDADeviceContext, plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
abs_grad, ops::AbsGradKernel<plat::CUDADeviceContext, float>,
ops::AbsGradKernel<plat::CUDADeviceContext, double>,
ops::AbsGradKernel<plat::CUDADeviceContext, int>,
ops::AbsGradKernel<plat::CUDADeviceContext, int64_t>,
ops::AbsGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
abs_grad_grad, ops::AbsDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, double>,
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class AbsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
auto numel = x->numel();
auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<math::Real<T>>(
context.GetPlace(), size_t(x->numel() * sizeof(math::Real<T>)));
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
math::AbsFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
};
template <typename DeviceContext, typename T>
class AbsGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const framework::Tensor* d_out =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
framework::Tensor* d_x =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto numel = d_out->numel();
auto* dout_data = d_out->data<math::Real<T>>();
auto* x_data = x->data<T>();
auto* dx_data = d_x->mutable_data<T>(
ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
math::AbsGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
for_range(functor);
}
};
template <typename DeviceContext, typename T>
class AbsDoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const framework::Tensor* ddx = ctx.Input<framework::Tensor>("DDX");
const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
framework::Tensor* ddout = ctx.Output<framework::Tensor>("DDOut");
auto numel = ddx->numel();
auto* ddx_data = ddx->data<T>();
auto* x_data = x->data<T>();
auto* ddout_data = ddout->mutable_data<T>(
ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
math::AbsGradGradFunctor<T> functor(ddx_data, x_data, ddout_data, numel);
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the Licnse. */
#include "paddle/fluid/operators/abs_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/device_context.h"
namespace pten {
template <typename T, typename Context>
void AbsGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& dout,
DenseTensor* dx);
template <typename T, typename Context>
void AbsDoubleGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& ddx,
DenseTensor* ddout);
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/device_context.h"
namespace pten {
template <typename T, typename Context>
void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out);
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/pten/common/complex.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/abs_grad_kernel_impl.h"
using pten::dtype::complex;
PT_REGISTER_KERNEL(abs_grad,
CPU,
ALL_LAYOUT,
pten::AbsGradKernel,
float,
double,
int,
int64_t,
complex<float>,
complex<double>) {}
PT_REGISTER_KERNEL(abs_double_grad,
CPU,
ALL_LAYOUT,
pten::AbsDoubleGradKernel,
float,
double,
int,
int64_t,
complex<float>,
complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/abs_kernel.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/common/complex.h"
#include "paddle/pten/core/kernel_registry.h"
namespace pten {
template <typename T, typename Context>
void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
ctx.template Alloc<paddle::operators::math::Real<T>>(
out, size_t(x.numel() * sizeof(paddle::operators::math::Real<T>)));
auto* out_data = out->data<paddle::operators::math::Real<T>>();
paddle::platform::ForRange<Context> for_range(ctx, numel);
paddle::operators::math::AbsFunctor<T> functor(x_data, out_data, numel);
for_range(functor);
}
} // namespace pten
PT_REGISTER_KERNEL(abs,
CPU,
ALL_LAYOUT,
pten::AbsKernel,
float,
double,
int,
int64_t,
pten::dtype::complex<float>,
pten::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/common/complex.h"
#include "paddle/pten/common/float16.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/abs_grad_kernel.h"
#include "paddle/pten/kernels/impl/abs_grad_kernel_impl.h"
using pten::dtype::complex;
PT_REGISTER_KERNEL(abs_grad,
GPU,
ALL_LAYOUT,
pten::AbsGradKernel,
float,
double,
int,
int64_t,
pten::dtype::float16,
complex<float>,
complex<double>) {}
PT_REGISTER_KERNEL(abs_double_grad,
GPU,
ALL_LAYOUT,
pten::AbsDoubleGradKernel,
float,
double,
int,
int64_t,
pten::dtype::float16,
complex<float>,
complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/abs_kernel.h"
#include "paddle/pten/kernels/funcs/elementwise_base.h"
namespace pten {
template <typename T, typename Enable = void>
struct CudaAbsFunctor;
template <typename T>
struct CudaAbsFunctor<
T,
paddle::operators::math::Complex<T, paddle::operators::math::Real<T>>> {
__device__ __forceinline__ paddle::operators::math::Real<T> operator()(
const T x) const {
return abs(x);
}
};
template <typename T>
struct CudaAbsFunctor<
T,
paddle::operators::math::NoComplex<T, paddle::operators::math::Real<T>>> {
__device__ __forceinline__ T operator()(const T x) const {
return std::abs(x);
}
};
template <typename T, typename Context>
void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<paddle::operators::math::Real<T>>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
auto functor = CudaAbsFunctor<T>();
funcs::LaunchSameDimsElementwiseCudaKernel<paddle::operators::math::Real<T>>(
ctx, ins, &outs, functor);
}
} // namespace pten
PT_REGISTER_KERNEL(abs,
GPU,
ALL_LAYOUT,
pten::AbsKernel,
float,
double,
int,
int64_t,
pten::dtype::float16,
pten::dtype::complex<float>,
pten::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/pten/kernels/abs_grad_kernel.h"
namespace pten {
template <typename T, typename Context>
void AbsGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& dout,
DenseTensor* dx) {
auto numel = dout.numel();
auto* dout_data = dout.data<paddle::operators::math::Real<T>>();
auto* x_data = x.data<T>();
ctx.template Alloc<T>(dx, static_cast<size_t>(numel * sizeof(T)));
auto* dx_data = dx->data<T>();
paddle::platform::ForRange<Context> for_range(ctx, numel);
paddle::operators::math::AbsGradFunctor<T> functor(
dout_data, x_data, dx_data, numel);
for_range(functor);
}
template <typename T, typename Context>
void AbsDoubleGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& ddx,
DenseTensor* ddout) {
auto numel = ddx.numel();
auto* ddx_data = ddx.data<T>();
auto* x_data = x.data<T>();
ctx.template Alloc<T>(ddout, static_cast<size_t>(numel * sizeof(T)));
auto* ddout_data = ddout->data<T>();
paddle::platform::ForRange<Context> for_range(ctx, numel);
paddle::operators::math::AbsGradGradFunctor<T> functor(
ddx_data, x_data, ddout_data, numel);
for_range(functor);
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature AbsOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("abs", {"X"}, {}, {"Out"});
}
KernelSignature AbsGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"abs_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
}
KernelSignature AbsDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("abs_double_grad", {"X", "DDX"}, {}, {"DDOut"});
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(abs, pten::AbsOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(abs_grad, pten::AbsGradOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(abs_double_grad,
pten::AbsDoubleGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册