diff --git a/paddle/operators/clip_op.cc b/paddle/operators/clip_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8eea843889fbf19689f6baa819f2c228684bfa5b
--- /dev/null
+++ b/paddle/operators/clip_op.cc
@@ -0,0 +1,76 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/clip_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class ClipOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto max = GetAttr<float>("max");
+    auto min = GetAttr<float>("min");
+    PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
+    ctx.Output<Tensor>("Out")->Resize(x_dims);
+  }
+};
+
+class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ClipOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input of clip op");
+    AddOutput("Out", "The output of clip op");
+    AddComment(R"DOC(
+Clip Operator.
+
+The clip operator limits the value of the given input within an interval
+[min, max], i.e. Out = min(max(X, min), max).
+)DOC");
+    AddAttr<float>("min", "min value to be clipped.");
+    AddAttr<float>("max", "max value to be clipped.");
+  }
+};
+
+class ClipOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    x_grad->Resize(x_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker, clip_grad, ops::ClipOpGrad);
+REGISTER_OP_CPU_KERNEL(clip,
+                       ops::ClipKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(clip_grad, ops::ClipGradKernel<float>);
diff --git a/paddle/operators/clip_op.cu b/paddle/operators/clip_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..51941deecec95a8096e3500c5b9c1e54db00912e
--- /dev/null
+++ b/paddle/operators/clip_op.cu
@@ -0,0 +1,67 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/clip_op.h"
+
+#define CUDA_1D_KERNEL_LOOP(i, n)                            \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+       i += blockDim.x * gridDim.x)
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+__global__ void ClipGradientKernel(const int N, const T min, const T max,
+                                   const T* Y, const T* dY, T* dX) {
+  CUDA_1D_KERNEL_LOOP(i, N) { dX[i] = dY[i] * (Y[i] > min && Y[i] < max); }
+}
+
+template <typename T>
+class ClipGradientOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto max = context.op().GetAttr<float>("max");
+    auto min = context.op().GetAttr<float>("min");
+    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* x = context.Input<Tensor>("X");
+    auto dims = d_x->dims();
+    size_t count = 1;
+    for (int i = 0; i < dims.size(); ++i) {
+      count *= dims[i];
+    }
+    auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
+    auto d_out_data = d_out->data<T>();
+    auto x_data = x->data<T>();
+
+    int N = d_x->dims()[0];
+    int D = d_x->dims()[1];
+    int block = 512;
+    int grid = (N * D + block - 1) / block;
+
+    ClipGradientKernel<T><<<grid, block>>>(count, min, max, x_data,
+                                           d_out_data, d_x_data);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(clip,
+                       ops::ClipKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(clip_grad, ops::ClipGradientOpCUDAKernel<float>);
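Note on the backward semantics: both ClipGradientKernel above and the CPU ClipGradKernel in clip_op.h below implement dX = dY where min < X < max, and dX = 0 elsewhere, so clipped elements receive no gradient. A minimal NumPy sketch of the same rule, for reference only (not part of the patch; the function name is illustrative):

import numpy as np

def clip_grad_reference(x, d_out, min_val, max_val):
    # Pass the upstream gradient through only where x lies strictly
    # inside (min_val, max_val); clipped elements get zero gradient.
    mask = (x > min_val) & (x < max_val)
    return d_out * mask.astype(d_out.dtype)

x = np.array([0.05, 0.5, 0.95], dtype=np.float32)
print clip_grad_reference(x, np.ones_like(x), 0.1, 0.9)  # [ 0.  1.  0.]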
diff --git a/paddle/operators/clip_op.h b/paddle/operators/clip_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..b9a2c61f72e988a366b2b078362c3928db7ace10
--- /dev/null
+++ b/paddle/operators/clip_op.h
@@ -0,0 +1,70 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class ClipKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto max = context.op().GetAttr<float>("max");
+    auto min = context.op().GetAttr<float>("min");
+    auto* x = context.Input<Tensor>("X");
+    auto* out = context.Output<Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+    auto x_tensor = EigenTensor<T, 2>::From(*x);
+    auto out_tensor = EigenTensor<T, 2>::From(*out);
+    auto place = context.GetEigenDevice<Place>();
+    out_tensor.device(place) = x_tensor.cwiseMin(max).cwiseMax(min);
+  }
+};
+
+template <typename T>
+class ClipGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto max = context.op().GetAttr<float>("max");
+    auto min = context.op().GetAttr<float>("min");
+    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* x = context.Input<Tensor>("X");
+    auto dims = d_x->dims();
+    size_t count = 1;
+    for (int i = 0; i < dims.size(); ++i) {
+      count *= dims[i];
+    }
+
+    auto d_x_data = d_x->mutable_data<T>(context.GetPlace());
+    auto d_out_data = d_out->data<T>();
+    auto x_data = x->data<T>();
+    for (size_t i = 0; i < count; ++i) {
+      d_x_data[i] = d_out_data[i] * (x_data[i] > min && x_data[i] < max);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
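The forward kernel computes x.cwiseMin(max).cwiseMax(min): values are first capped at max, then raised to at least min. Because ClipOp::InferShape enforces min < max, this composition matches numpy.clip, which the Python test below uses as its reference. A quick check of the equivalence (illustrative only, not part of the patch):

import numpy as np

x = np.random.random((16, 16)).astype("float32")
out = np.maximum(np.minimum(x, 0.9), 0.1)  # cwiseMin(max), then cwiseMax(min)
assert np.allclose(out, np.clip(x, 0.1, 0.9))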
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 6896422617be0a3c73dc7b0d7cc1113075fa2f4b..2200bc2af2a8d9e574e8d68802a9f1dcc8c1d83c 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -48,6 +48,7 @@ USE_NO_KERNEL_OP(identity);
 USE_OP(minus);
 USE_CPU_ONLY_OP(gather);
 USE_CPU_ONLY_OP(scatter);
+USE_OP(clip);
 
 namespace paddle {
 namespace framework {
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index 3bc05a0feccbbd3d5e7852d85bd3dc8edaccfd07..bb2d0a3f3a475c2d351b4d4be5b0ca237de632a9 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -34,8 +34,10 @@ class OpTestMeta(type):
                         arr = self.inputs[in_name]
                         var.set_dims(arr.shape)
                         var.set(arr, place)
+                        print "var: %s" % in_name
                     else:
                         kwargs[in_name] = "@EMPTY@"
+                        print "var: %s=EMPTY" % in_name
 
                 for out_name in Operator.get_op_output_names(self.type):
                     if not hasattr(self, "outputs"):
@@ -46,6 +48,7 @@
                                          (out_name))
                     kwargs[out_name] = out_name
                     scope.new_var(out_name).get_tensor()
+                    print "var: %s" % out_name
 
                 for attr_name in Operator.get_op_attr_names(self.type):
                     if hasattr(self, "attrs") and attr_name in self.attrs:
@@ -62,7 +65,9 @@
 
                 for out_name in Operator.get_op_output_names(self.type):
                     actual = numpy.array(scope.find_var(out_name).get_tensor())
+                    print "actual: %s" % actual
                     expect = self.outputs[out_name]
+                    print "expect: %s" % expect
                     self.assertTrue(
                         numpy.allclose(
                             actual, expect, atol=1e-05),
diff --git a/python/paddle/v2/framework/tests/test_clip_op.py b/python/paddle/v2/framework/tests/test_clip_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dd09801917458f888826bb14f5c9e12a07aead4
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_clip_op.py
@@ -0,0 +1,39 @@
+import unittest
+import numpy as np
+from paddle.v2.framework.op import Operator
+from gradient_checker import GradientChecker
+from op_test_util import OpTestMeta
+
+
+class TestClipOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        input = np.random.random((16, 16)).astype("float32")
+        print "input: %s" % input
+        self.type = "clip"
+        self.inputs = {'X': input, }
+        self.attrs = {}
+        self.attrs['min'] = 0.1
+        self.attrs['max'] = 0.9
+        self.outputs = {
+            'Out': np.clip(self.inputs['X'], self.attrs['min'],
+                           self.attrs['max'])
+        }
+
+
+class TestClipGradOp(GradientChecker):
+    def setUp(self):
+        self.op = Operator(type="clip", X="X", Out="Out", min=0.1, max=0.9)
+        self.inputs = {'X': np.random.random((16, 16)).astype("float32"), }
+
+    def test_normal(self):
+        self.check_grad(
+            self.op, self.inputs, set(["X"]), "Out", max_relative_error=0.5)
+
+    def test_cpu_gpu_compare(self):
+        self.compare_grad(self.op, self.inputs)
+
+
+if __name__ == '__main__':
+    unittest.main()
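A note on the loose max_relative_error=0.5 in test_normal: clip is not differentiable at the min/max boundaries, and with inputs drawn uniformly from [0, 1) a sizable fraction of elements lands near or beyond 0.1/0.9, where the finite-difference estimate used by GradientChecker can disagree badly with the analytic 0/1 mask. A small sketch of the effect (illustrative only; the 0.005 step size is an assumed perturbation, not taken from the patch):

import numpy as np

def numeric_clip_grad(x, mn, mx, delta=0.005):
    # Central finite difference of clip(x, mn, mx), element-wise.
    return (np.clip(x + delta, mn, mx) - np.clip(x - delta, mn, mx)) / (2 * delta)

x = np.array([0.9, 0.9004])  # at and just above the max bound
analytic = ((x > 0.1) & (x < 0.9)).astype("float64")
print analytic                        # [ 0.  0.]
print numeric_clip_grad(x, 0.1, 0.9)  # ~[ 0.5   0.46], large error near the bound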