未验证 提交 b7bcd0f6 编写于 作者: A Aurelius84 提交者: GitHub

[Phi] Migrate lable_smooth_op into Phi (#39796)

* [Phi] Migrate lable_smooth_op into Phi

* fix PT->PD
上级 24f55aed
......@@ -2040,7 +2040,7 @@ void OperatorWithKernel::BuildPtenKernelContext(
(i == 0 ? 0 : pt_kernel_context->InputRangeAt(i - 1).second);
// deal with optional here
if ((it == ctx.inputs.end()) &&
if ((it == ctx.inputs.end() || it->second.size() == 0) &&
(input_defs[i].type_index ==
std::type_index(typeid(paddle::optional<const phi::DenseTensor&>)))) {
pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr);
......
......@@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/label_smooth_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
......@@ -152,11 +151,3 @@ REGISTER_OPERATOR(label_smooth, ops::LabelSmoothOp, ops::LabelSmoothOpMaker,
ops::LabelSmoothGradMaker<paddle::framework::OpDesc>,
ops::LabelSmoothGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(label_smooth_grad, ops::LabelSmoothGradOp);
REGISTER_OP_CPU_KERNEL(
label_smooth,
ops::LabelSmoothKernel<paddle::platform::CPUDeviceContext, float>,
ops::LabelSmoothKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
label_smooth_grad,
ops::LabelSmoothGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LabelSmoothGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/label_smooth_op.h"
namespace paddle {
namespace operators {
template <typename T>
struct LabelSmoothFunctor {
T epsilon;
T label_dim;
__forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data) {
epsilon = static_cast<T>(epsilon_data);
label_dim = static_cast<T>(label_dim_data);
}
__device__ __forceinline__ T operator()(const T x) const {
return (static_cast<T>(1 - epsilon) * x +
static_cast<T>(epsilon / label_dim));
}
};
template <typename T>
struct LabelSmoothGradFunctor {
T epsilon;
__forceinline__ LabelSmoothGradFunctor(float epsilon_data) {
epsilon = static_cast<T>(epsilon_data);
}
__device__ __forceinline__ T operator()(const T x) const {
return static_cast<T>(1 - epsilon) * x;
}
};
template <typename T>
__global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
const int dist_numel, const T* src,
const T* dist_data, T* dst) {
CUDA_KERNEL_LOOP(idx, N) {
int dist_idx = idx % dist_numel;
dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
static_cast<T>(epsilon) * dist_data[dist_idx];
}
}
template <typename DeviceContext, typename T>
class LabelSmoothGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::LoDTensor>("Out");
auto* in_t = ctx.Input<framework::LoDTensor>("X");
auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
auto label_dim = in_t->dims()[in_t->dims().size() - 1];
auto epsilon = ctx.Attr<float>("epsilon");
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
auto size_prob = in_t->numel();
const T* in_data = in_t->data<T>();
T* out_data = out_t->mutable_data<T>(ctx.GetPlace());
int threads = 512;
int grid = (size_prob + threads - 1) / threads;
auto stream = ctx.cuda_device_context().stream();
if (dist_t) {
auto dist_numel = dist_t->numel();
const T* dist_data = dist_t->data<T>();
LabelSmoothRunDistKernel<T><<<grid, threads, 0, stream>>>(
size_prob, epsilon, dist_numel, in_data, dist_data, out_data);
} else {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
std::vector<const framework::Tensor*> ins = {in_t};
std::vector<framework::Tensor*> outs = {out_t};
auto functor = LabelSmoothFunctor<T>(epsilon, label_dim);
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
&outs, functor);
}
}
};
template <typename DeviceContext, typename T>
class LabelSmoothGradGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_in_t = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_in_t->mutable_data<T>(ctx.GetPlace());
auto epsilon = ctx.Attr<float>("epsilon");
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
std::vector<const framework::Tensor*> ins = {d_out_t};
std::vector<framework::Tensor*> outs = {d_in_t};
auto functor = LabelSmoothGradFunctor<T>(epsilon);
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
&outs, functor);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
label_smooth,
ops::LabelSmoothGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::LabelSmoothGPUKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
label_smooth_grad,
ops::LabelSmoothGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::LabelSmoothGradGPUKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class LabelSmoothKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::LoDTensor>("Out");
auto* in_t = ctx.Input<framework::LoDTensor>("X");
auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
auto label_dim = in_t->dims()[in_t->dims().size() - 1];
out_t->mutable_data<T>(ctx.GetPlace());
if (label_dim != 0) {
auto epsilon = ctx.Attr<float>("epsilon");
auto out = framework::EigenVector<T>::Flatten(*out_t);
auto in = framework::EigenVector<T>::Flatten(*in_t);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
if (dist_t) {
auto dist = framework::EigenVector<T>::Flatten(*dist_t);
out.device(dev) = static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon) *
dist.broadcast(Eigen::DSizes<int, 1>(
in_t->numel() / label_dim));
} else {
out.device(dev) = static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon / label_dim);
}
}
}
};
template <typename DeviceContext, typename T>
class LabelSmoothGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_in_t = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_in_t->mutable_data<T>(ctx.GetPlace());
auto d_out_dim = d_out_t->dims()[d_out_t->dims().size() - 1];
if (d_out_dim != 0) {
auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
auto d_in = framework::EigenVector<T>::Flatten(*d_in_t);
auto epsilon = ctx.Attr<float>("epsilon");
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
d_in.device(dev) = static_cast<T>(1 - epsilon) * d_out;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/label_smooth_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/label_smooth_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/label_smooth_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void LabelSmoothGradKernel(const Context& ctx,
const DenseTensor& out_grad,
float epsilon,
DenseTensor* label_grad) {
ctx.template Alloc<T>(label_grad);
auto d_out_dim = out_grad.dims()[out_grad.dims().size() - 1];
if (d_out_dim != 0) {
auto d_out = EigenVector<T>::Flatten(out_grad);
auto d_in = EigenVector<T>::Flatten(*label_grad);
auto& dev = *ctx.eigen_device();
d_in.device(dev) = static_cast<T>(1 - epsilon) * d_out;
}
}
} // namespace phi
PD_REGISTER_KERNEL(label_smooth_grad,
CPU,
ALL_LAYOUT,
phi::LabelSmoothGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/label_smooth_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
template <typename T, typename Context>
void LabelSmoothKernel(const Context& ctx,
const DenseTensor& label,
paddle::optional<const DenseTensor&> prior_dist,
float epsilon,
DenseTensor* out) {
auto label_dim = label.dims()[label.dims().size() - 1];
ctx.template Alloc<T>(out);
auto& dev = *ctx.eigen_device();
if (label_dim != 0) {
auto eigen_out = EigenVector<T>::Flatten(*out);
auto eigen_in = EigenVector<T>::Flatten(label);
if (prior_dist.is_initialized()) {
auto dist = EigenVector<T>::Flatten(*prior_dist.get_ptr());
eigen_out.device(dev) =
static_cast<T>(1 - epsilon) * eigen_in +
static_cast<T>(epsilon) *
dist.broadcast(Eigen::DSizes<int, 1>(label.numel() / label_dim));
} else {
eigen_out.device(dev) = static_cast<T>(1 - epsilon) * eigen_in +
static_cast<T>(epsilon / label_dim);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(
label_smooth, CPU, ALL_LAYOUT, phi::LabelSmoothKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/label_smooth_grad_kernel.h"
namespace phi {
template <typename T>
struct LabelSmoothGradFunctor {
T epsilon;
__forceinline__ LabelSmoothGradFunctor(float epsilon_data) {
epsilon = static_cast<T>(epsilon_data);
}
__device__ __forceinline__ T operator()(const T x) const {
return static_cast<T>(1 - epsilon) * x;
}
};
template <typename T, typename Context>
void LabelSmoothGradKernel(const Context& ctx,
const DenseTensor& out_grad,
float epsilon,
DenseTensor* label_grad) {
ctx.template Alloc<T>(label_grad);
std::vector<const DenseTensor*> ins = {&out_grad};
std::vector<DenseTensor*> outs = {label_grad};
auto functor = LabelSmoothGradFunctor<T>(epsilon);
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
ctx, ins, &outs, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(label_smooth_grad,
GPU,
ALL_LAYOUT,
phi::LabelSmoothGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/label_smooth_kernel.h"
namespace phi {
template <typename T>
struct LabelSmoothFunctor {
T epsilon;
T label_dim;
__forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data) {
epsilon = static_cast<T>(epsilon_data);
label_dim = static_cast<T>(label_dim_data);
}
__device__ __forceinline__ T operator()(const T x) const {
return (static_cast<T>(1 - epsilon) * x +
static_cast<T>(epsilon / label_dim));
}
};
template <typename T>
__global__ void LabelSmoothRunDistKernel(const int N,
const float epsilon,
const int dist_numel,
const T* src,
const T* dist_data,
T* dst) {
CUDA_KERNEL_LOOP(idx, N) {
int dist_idx = idx % dist_numel;
dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
static_cast<T>(epsilon) * dist_data[dist_idx];
}
}
template <typename T, typename Context>
void LabelSmoothKernel(const Context& ctx,
const DenseTensor& label,
paddle::optional<const DenseTensor&> prior_dist,
float epsilon,
DenseTensor* out) {
auto label_dim = label.dims()[label.dims().size() - 1];
auto size_prob = label.numel();
const T* in_data = label.data<T>();
T* out_data = ctx.template Alloc<T>(out);
if (prior_dist.get_ptr()) {
int threads = 512;
int grid = (size_prob + threads - 1) / threads;
auto stream = ctx.stream();
const auto* dist_t = prior_dist.get_ptr();
auto dist_numel = dist_t->numel();
const T* dist_data = dist_t->data<T>();
LabelSmoothRunDistKernel<T><<<grid, threads, 0, stream>>>(
size_prob, epsilon, dist_numel, in_data, dist_data, out_data);
} else {
std::vector<const DenseTensor*> ins = {&label};
std::vector<DenseTensor*> outs = {out};
auto functor = LabelSmoothFunctor<T>(epsilon, label_dim);
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
ctx, ins, &outs, functor);
}
}
} // namespace phi
PD_REGISTER_KERNEL(
label_smooth, GPU, ALL_LAYOUT, phi::LabelSmoothKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void LabelSmoothGradKernel(const Context& ctx,
const DenseTensor& out_grad,
float epsilon,
DenseTensor* label_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void LabelSmoothKernel(const Context& ctx,
const DenseTensor& label,
paddle::optional<const DenseTensor&> prior_dist,
float epsilon,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LabelSmoothOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"label_smooth", {"X", "PriorDist"}, {"epsilon"}, {"Out"});
}
KernelSignature LabelSmoothGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("label_smooth_grad",
{GradVarName("Out")},
{"epsilon"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(label_smooth, phi::LabelSmoothOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(label_smooth_grad,
phi::LabelSmoothGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册