Unverified Commit c0658045 authored by wuyefeilin, committed by GitHub

[phi] Move clip op to phi (#40602)

* move clip op to phi

* fix as review

* update hierarchical_sigmoid_kernel.cc

* update selected_rows

* update clip_kernel.cu

* fix as review
Parent e59a693e
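In outline, this follows the standard fluid-to-phi migration pattern: the clip functors and kernel bodies move from paddle/fluid/operators/clip_op.h into headers under paddle/phi/kernels/, the hand-written InferShape is replaced by an infer-meta functor, and kernel registration switches from REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL to PD_REGISTER_KERNEL. For example, the new CPU registration that appears later in this diff is:

PD_REGISTER_KERNEL(
    clip, CPU, ALL_LAYOUT, phi::ClipKernel, float, double, int, int64_t) {}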
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/clip_op.h"
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -23,15 +25,6 @@ namespace operators {
class ClipOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "clip");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "clip");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
@@ -176,23 +169,15 @@ class ClipDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
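// The hand-written ClipOp::InferShape removed above is now supplied by
// phi::UnchangedInferMeta through the functor declared below; it likewise
// sets Out's dims from X and shares the LoD, so shape semantics are unchanged.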
DECLARE_INFER_SHAPE_FUNCTOR(clip, ClipInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>,
ops::ClipGradOpMaker<paddle::framework::OpDesc>,
ops::ClipGradOpMaker<paddle::imperative::OpBase>,
ops::ClipInplaceInferer, ClipInferShapeFunctor);
REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer,
ops::ClipDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ClipDoubleGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>,
ops::ClipKernel<paddle::platform::CPUDeviceContext, double>,
ops::ClipKernel<paddle::platform::CPUDeviceContext, int>,
ops::ClipKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
clip_grad, ops::ClipGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ClipGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ClipGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ClipGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_VERSION(clip)
.AddCheckpoint(
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/clip_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
clip, ops::ClipKernel<paddle::platform::CUDADeviceContext, float>,
ops::ClipKernel<paddle::platform::CUDADeviceContext, double>,
ops::ClipKernel<paddle::platform::CUDADeviceContext, int>,
ops::ClipKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ClipKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
clip_grad, ops::ClipGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#endif
namespace paddle {
namespace operators {
using framework::Tensor;
using platform::Transform;
template <typename T>
class ClipFunctor {
public:
explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T x) const {
return x < min_ ? min_ : x > max_ ? max_ : x;
}
private:
T min_;
T max_;
};
template <typename T>
class ClipGradFunctor {
public:
explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T x, const T y) const {
return (y > min_ && y < max_) ? x : static_cast<T>(0);
}
private:
T min_;
T max_;
};
template <typename DeviceContext, typename T>
class ClipKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = static_cast<T>(context.Attr<float>("max"));
Tensor max_cpu;
if (context.HasInput("Max")) {
auto* max_t = context.Input<Tensor>("Max");
auto* max_data = max_t->data<T>();
if (platform::is_gpu_place(max_t->place())) {
paddle::framework::TensorCopySync(*max_t, platform::CPUPlace(),
&max_cpu);
max_data = max_cpu.data<T>();
}
max = max_data[0];
}
max = static_cast<T>(max);
auto min = static_cast<T>(context.Attr<float>("min"));
Tensor min_cpu;
if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min");
auto* min_data = min_t->data<T>();
if (platform::is_gpu_place(min_t->place())) {
paddle::framework::TensorCopySync(*min_t, platform::CPUPlace(),
&min_cpu);
min_data = min_cpu.data<T>();
}
min = min_data[0];
}
PADDLE_ENFORCE_LE(min, max,
platform::errors::InvalidArgument(
"max should be greater than or equal to min. "
"But received min = %f, max = %f",
static_cast<float>(min), static_cast<float>(max)));
auto* x_var = context.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) {
auto* x = context.Input<framework::LoDTensor>("X");
auto* out = context.Output<framework::LoDTensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
int64_t numel = x->numel();
if (platform::is_gpu_place(context.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
std::vector<const framework::Tensor*> ins = {x};
std::vector<framework::Tensor*> outs = {out};
auto functor = ClipFunctor<T>(min, max);
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
context.template device_context<platform::CUDADeviceContext>(), ins,
&outs, functor);
#endif
} else {
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x_data,
x_data + numel, out_data, ClipFunctor<T>(min, max));
}
} else if (x_var->IsType<phi::SelectedRows>()) {
auto* x = context.Input<phi::SelectedRows>("X");
auto* out = context.Output<phi::SelectedRows>("Out");
PADDLE_ENFORCE_NE(x, out, platform::errors::InvalidArgument(
"Inplace clip is not allowed "
"when x is SelectedRows"));
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(context.template device_context<DeviceContext>(), *x, out);
auto* out_tensor = out->mutable_value();
auto* out_data = out_tensor->data<T>();
int64_t numel = out_tensor->numel();
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), out_data,
out_data + numel, out_data, ClipFunctor<T>(min, max));
} else {
PADDLE_THROW(platform::errors::Unavailable(
"ClipOp only supports LoDTensor and SelectedRows."));
}
}
};
template <typename DeviceContext, typename T>
class ClipGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = static_cast<T>(context.Attr<float>("max"));
Tensor max_cpu;
if (context.HasInput("Max")) {
auto* max_t = context.Input<Tensor>("Max");
auto* max_data = max_t->data<T>();
if (platform::is_gpu_place(max_t->place())) {
paddle::framework::TensorCopySync(*max_t, platform::CPUPlace(),
&max_cpu);
max_data = max_cpu.data<T>();
}
max = max_data[0];
}
max = static_cast<T>(max);
auto min = static_cast<T>(context.Attr<float>("min"));
Tensor min_cpu;
if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min");
auto* min_data = min_t->data<T>();
if (platform::is_gpu_place(min_t->place())) {
paddle::framework::TensorCopySync(*min_t, platform::CPUPlace(),
&min_cpu);
min_data = min_cpu.data<T>();
}
min = min_data[0];
}
min = static_cast<T>(min);
auto* d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* d_x =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
auto* x = context.Input<framework::LoDTensor>("X");
#if defined(__NVCC__) || defined(__HIPCC__)
std::vector<const framework::Tensor*> ins = {d_out, x};
std::vector<framework::Tensor*> outs = {d_x};
auto functor = ClipGradFunctor<T>(min, max);
d_x->mutable_data<T>(context.GetPlace());
LaunchSameDimsElementwiseCudaKernel<T>(
context.template device_context<platform::CUDADeviceContext>(), ins,
&outs, functor);
#else
int64_t numel = d_out->numel();
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
const T* d_out_data = d_out->data<T>();
const T* x_data = x->data<T>();
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), d_out_data,
d_out_data + numel, x_data, d_x_data, ClipGradFunctor<T>(min, max));
#endif
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......
@@ -17,8 +17,8 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
namespace paddle {
namespace operators {
@@ -91,7 +91,7 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
T inv_s = inverse(s);
platform::Transform<platform::CPUDeviceContext> trans;
trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
out->mutable_data<T>(ctx.GetPlace()), phi::ClipFunctor<T>(-s, s));
auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
}
@@ -109,7 +109,7 @@ struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
platform::Transform<platform::CPUDeviceContext> trans;
trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
out->mutable_data<T>(ctx.GetPlace()), phi::ClipFunctor<T>(-s, s));
auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(*ctx.eigen_device()) =
(bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
@@ -144,7 +144,7 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
trans(ctx, start, end, out_data + i * channel_size,
phi::ClipFunctor<T>(-s, s));
}
for (int64_t i = 0; i < channel; i++) {
T s = scale_data[i];
@@ -163,7 +163,7 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
auto* start = in_data + i * step_i + j * step_j;
auto* end = in_data + i * step_i + (j + 1) * step_j;
auto* cur_out_data = out_data + i * step_i + j * step_j;
trans(ctx, start, end, cur_out_data, phi::ClipFunctor<T>(-s, s));
for (int k = 0; k < step_j; k++) {
cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]);
}
@@ -200,7 +200,7 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
trans(ctx, start, end, out_data + i * channel_size,
phi::ClipFunctor<T>(-s, s));
}
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
@@ -220,7 +220,7 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
auto* start = in_data + i * step_i + j * step_j;
auto* end = in_data + i * step_i + (j + 1) * step_j;
auto* cur_out_data = out_data + i * step_i + j * step_j;
trans(ctx, start, end, cur_out_data, phi::ClipFunctor<T>(-s, s));
for (int k = 0; k < step_j; k++) {
cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]) *
s / static_cast<T>(bin_cnt);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void ClipGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const Scalar& min,
const Scalar& max,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
template <typename T, typename Context>
void ClipKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& min,
const Scalar& max,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/clip_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h"
PD_REGISTER_KERNEL(clip_grad,
CPU,
ALL_LAYOUT,
phi::ClipGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/clip_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
PD_REGISTER_KERNEL(
clip, CPU, ALL_LAYOUT, phi::ClipKernel, float, double, int, int64_t) {}
@@ -14,7 +14,6 @@
#include "paddle/phi/kernels/hierarchical_sigmoid_kernel.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
@@ -22,6 +21,7 @@
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function_impl.h"
#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
namespace phi {
@@ -92,8 +92,7 @@ void HierarchicalSigmoidKernel(const Context& ctx,
pre_out_data,
pre_out_data + pre_out->numel(),
pre_out_data,
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
bit_code->Sum(*pre_out, out, static_cast<T>(-1));
// use softrelu to calculate cross entropy
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/clip_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_grad_kernel_impl.h"
PD_REGISTER_KERNEL(clip_grad,
GPU,
ALL_LAYOUT,
phi::ClipGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/clip_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
PD_REGISTER_KERNEL(clip,
GPU,
ALL_LAYOUT,
phi::ClipKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/clip_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/transform.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#endif
namespace phi {
template <typename T>
class ClipGradFunctor {
public:
explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T x, const T y) const {
return (y > min_ && y < max_) ? x : static_cast<T>(0);
}
private:
T min_;
T max_;
};
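// Gradient mask example (hypothetical values): with ClipGradFunctor<float>
// g(0.f, 1.f), g(dout, 0.5f) == dout because 0.5 lies strictly inside
// (min_, max_), while g(dout, 1.0f) == 0 because the comparisons are
// strict, so values sitting exactly on a clip boundary get zero gradient.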
template <typename T, typename Context>
void ClipGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const Scalar& min,
const Scalar& max,
DenseTensor* x_grad) {
auto max_ = max.to<T>();
auto min_ = min.to<T>();
#if defined(__NVCC__) || defined(__HIPCC__)
std::vector<const DenseTensor*> ins = {&out_grad, &x};
std::vector<DenseTensor*> outs = {x_grad};
auto functor = ClipGradFunctor<T>(min_, max_);
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
#else
int64_t numel = out_grad.numel();
auto* d_x_data = dev_ctx.template Alloc<T>(x_grad);
const T* d_out_data = out_grad.data<T>();
const T* x_data = x.data<T>();
paddle::platform::Transform<Context> trans;
trans(dev_ctx,
d_out_data,
d_out_data + numel,
x_data,
d_x_data,
ClipGradFunctor<T>(min_, max_));
#endif
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/clip_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#endif
namespace phi {
template <typename T>
class ClipFunctor {
public:
explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T x) const {
return x < min_ ? min_ : x > max_ ? max_ : x;
}
private:
T min_;
T max_;
};
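// Example (hypothetical values): ClipFunctor<float> f(0.f, 1.f) gives
// f(-2.f) == 0.f, f(0.5f) == 0.5f, f(3.f) == 1.f, i.e. the usual clamp.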
template <typename T, typename Context>
void ClipKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& min,
const Scalar& max,
DenseTensor* out) {
auto max_ = max.to<T>();
auto min_ = min.to<T>();
PADDLE_ENFORCE_LE(
min_,
max_,
errors::InvalidArgument("max should be greater than or equal to min. "
"But received min = %f, max = %f",
static_cast<float>(min_),
static_cast<float>(max_)));
T* out_data = dev_ctx.template Alloc<T>(out);
const T* x_data = x.data<T>();
int64_t numel = x.numel();
if (paddle::platform::is_gpu_place(dev_ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
auto functor = ClipFunctor<T>(min_, max_);
phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
#endif
} else {
paddle::platform::Transform<Context> trans;
trans(
dev_ctx, x_data, x_data + numel, out_data, ClipFunctor<T>(min_, max_));
}
}
} // namespace phi
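// Usage sketch (illustrative only, not part of this diff): given a
// phi::CPUContext dev_ctx and an initialized DenseTensor x, the kernel
// can be invoked directly as
//
//   phi::DenseTensor out;
//   out.Resize(x.dims());
//   phi::ClipKernel<float>(dev_ctx, x, phi::Scalar(0.f), phi::Scalar(1.f),
//                          &out);
//
// In normal operation the framework dispatches to it through the
// PD_REGISTER_KERNEL entries above instead.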
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void ClipSparseKernel(const Context& dev_ctx,
const SelectedRows& x,
const Scalar& min,
const Scalar& max,
SelectedRows* out);
} // namespace sr
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selected_rows/clip_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h"
PD_REGISTER_KERNEL(clip_sr,
CPU,
ALL_LAYOUT,
phi::sr::ClipSparseKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selected_rows/clip_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/selected_rows/impl/clip_kernel_impl.h"
PD_REGISTER_KERNEL(clip_sr,
GPU,
ALL_LAYOUT,
phi::sr::ClipSparseKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/selected_rows/clip_kernel.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void ClipSparseKernel(const Context& dev_ctx,
const SelectedRows& x,
const Scalar& min,
const Scalar& max,
SelectedRows* out) {
auto max_ = max.to<T>();
auto min_ = min.to<T>();
PADDLE_ENFORCE_LE(
min_,
max_,
errors::InvalidArgument("max should be greater than or equal to min. "
"But received min = %f, max = %f",
static_cast<float>(min_),
static_cast<float>(max_)));
PADDLE_ENFORCE_NE(&x,
out,
errors::InvalidArgument("Inplace clip is not allowed "
"when x is SelectedRows"));
paddle::operators::math::scatter::MergeAdd<Context, T> merge_func;
merge_func(dev_ctx, x, out);
auto* out_tensor = out->mutable_value();
auto* out_data = out_tensor->data<T>();
int64_t numel = out_tensor->numel();
paddle::platform::Transform<Context> trans;
trans(dev_ctx,
out_data,
out_data + numel,
out_data,
ClipFunctor<T>(min_, max_));
}
} // namespace sr
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h"
namespace phi {
KernelSignature ClipOpArgumentMapping(const ArgumentMappingContext& ctx) {
paddle::SmallVector<std::string> attr_names;
attr_names.emplace_back(ctx.HasInput("Min") ? "Min" : "min");
attr_names.emplace_back(ctx.HasInput("Max") ? "Max" : "max");
if (ctx.IsDenseTensorInput("X")) {
if (ctx.HasInput("Min")) {
if (ctx.HasInput("Max")) {
return KernelSignature("clip", {"X"}, {"Min", "Max"}, {"Out"});
} else {
return KernelSignature("clip", {"X"}, {"Min", "max"}, {"Out"});
}
} else {
if (ctx.HasInput("Max")) {
return KernelSignature("clip", {"X"}, {"min", "Max"}, {"Out"});
} else {
return KernelSignature("clip", {"X"}, {"min", "max"}, {"Out"});
}
}
} else if (ctx.IsSelectedRowsInput("X")) {
if (ctx.HasInput("Min")) {
if (ctx.HasInput("Max")) {
return KernelSignature("clip_sr", {"X"}, {"Min", "Max"}, {"Out"});
} else {
return KernelSignature("clip_sr", {"X"}, {"Min", "max"}, {"Out"});
}
} else {
if (ctx.HasInput("Max")) {
return KernelSignature("clip_sr", {"X"}, {"min", "Max"}, {"Out"});
} else {
return KernelSignature("clip_sr", {"X"}, {"min", "max"}, {"Out"});
}
}
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ClipGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Min")) {
if (ctx.HasInput("Max")) {
return KernelSignature("clip_grad",
{"X", GradVarName("Out")},
{"Min", "Max"},
{GradVarName("X")});
} else {
return KernelSignature("clip_grad",
{"X", GradVarName("Out")},
{"Min", "max"},
{GradVarName("X")});
}
} else {
if (ctx.HasInput("Max")) {
return KernelSignature("clip_grad",
{"X", GradVarName("Out")},
{"min", "Max"},
{GradVarName("X")});
} else {
return KernelSignature("clip_grad",
{"X", GradVarName("Out")},
{"min", "max"},
{GradVarName("X")});
}
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(clip, phi::ClipOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(clip_grad, phi::ClipGradOpArgumentMapping);
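Note that the mapping prefers the tensor inputs Min/Max whenever they are wired and falls back to the float attributes min/max otherwise, so every combination resolves to a distinct signature; for instance, an op with a live Min tensor but only the max attribute maps to

KernelSignature("clip", {"X"}, {"Min", "max"}, {"Out"})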