/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { using framework::Tensor; using platform::Transform; template class ClipFunctor { public: explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {} HOSTDEVICE T operator()(const T& x) const { if (x < min_) return min_; else if (x > max_) return max_; else return x; } private: T min_; T max_; }; template class ClipGradFunctor { public: explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {} HOSTDEVICE T operator()(const T& x, const T& y) const { return (y > min_ && y < max_) ? x : 0; } private: T min_; T max_; }; template class ClipKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto max = static_cast(context.Attr("max")); Tensor max_cpu; if (context.HasInput("Max")) { auto* max_t = context.Input("Max"); auto* max_data = max_t->data(); if (platform::is_gpu_place(max_t->place())) { TensorCopySync(*max_t, platform::CPUPlace(), &max_cpu); max_data = max_cpu.data(); } max = max_data[0]; } max = static_cast(max); auto min = context.Attr("min"); Tensor min_cpu; if (context.HasInput("Min")) { auto* min_t = context.Input("Min"); auto* min_data = min_t->data(); if (platform::is_gpu_place(min_t->place())) { TensorCopySync(*min_t, platform::CPUPlace(), &min_cpu); min_data = min_cpu.data(); } min = min_data[0]; } min = static_cast(min); PADDLE_ENFORCE_LT(min, max, platform::errors::InvalidArgument( "max should be greater than min. " "But received min = %f, max = %f", min, max)); auto* x_var = context.InputVar("X"); if (x_var->IsType()) { auto* x = context.Input("X"); auto* out = context.Output("Out"); T* out_data = out->mutable_data(context.GetPlace()); const T* x_data = x->data(); int64_t numel = x->numel(); Transform trans; trans(context.template device_context(), x_data, x_data + numel, out_data, ClipFunctor(min, max)); } else if (x_var->IsType()) { auto* x = context.Input("X"); auto* out = context.Output("Out"); PADDLE_ENFORCE_NE( x, out, platform::errors::InvalidArgument( "Inplace clip is not allowed when x is SelectedRows")); math::scatter::MergeAdd merge_func; merge_func(context.template device_context(), *x, out); auto* out_tensor = out->mutable_value(); auto* out_data = out_tensor->data(); int64_t numel = out_tensor->numel(); Transform trans; trans(context.template device_context(), out_data, out_data + numel, out_data, ClipFunctor(min, max)); } else { PADDLE_THROW("ClipOp only supports LoDTensor and SelectedRows"); } } }; template class ClipGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto max = static_cast(context.Attr("max")); Tensor max_cpu; if (context.HasInput("Max")) { auto* max_t = context.Input("Max"); auto* max_data = max_t->data(); if (platform::is_gpu_place(max_t->place())) { TensorCopySync(*max_t, platform::CPUPlace(), &max_cpu); max_data = max_cpu.data(); } max = max_data[0]; } max = static_cast(max); auto min = context.Attr("min"); Tensor min_cpu; if (context.HasInput("Min")) { auto* min_t = context.Input("Min"); auto* min_data = min_t->data(); if (platform::is_gpu_place(min_t->place())) { TensorCopySync(*min_t, platform::CPUPlace(), &min_cpu); min_data = min_cpu.data(); } min = min_data[0]; } min = static_cast(min); auto* d_out = context.Input(framework::GradVarName("Out")); auto* d_x = context.Output(framework::GradVarName("X")); if (d_x != nullptr) { auto* x = context.Input("X"); int64_t numel = d_out->numel(); auto* d_x_data = d_x->mutable_data(context.GetPlace()); const T* d_out_data = d_out->data(); const T* x_data = x->data(); Transform trans; trans(context.template device_context(), d_out_data, d_out_data + numel, x_data, d_x_data, ClipGradFunctor(min, max)); } } }; } // namespace operators } // namespace paddle