diff --git a/paddle/fluid/framework/unused_var_check.cc b/paddle/fluid/framework/unused_var_check.cc index 0f8465ab8948e425ec48d10052643699e3c10ce7..f8ace3e85a643e8166da2b2e6f35a8097761b8cd 100644 --- a/paddle/fluid/framework/unused_var_check.cc +++ b/paddle/fluid/framework/unused_var_check.cc @@ -75,6 +75,7 @@ static const std::unordered_set &GetOpWithUnusedVarAllowSet() { "data_norm_grad", // 0 "update_loss_scaling", // 0 "fused_embedding_eltwise_layernorm", // 0 + "trunc_grad", // 1 }); return *allow_set; } diff --git a/paddle/fluid/operators/trunc_op.cc b/paddle/fluid/operators/trunc_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2b79e2152b2f3414c3e3b7794e8c07c00a2aee00 --- /dev/null +++ b/paddle/fluid/operators/trunc_op.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/trunc_op.h" + +namespace paddle { +namespace operators { + +class TruncOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "trunc"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "trunc"); + auto input_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim("Out", input_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class TruncOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of trunc op."); + AddOutput("Out", "(Tensor), The output tensor of trunc op."); + AddComment(R"DOC( +Trunc Operator. +Returns a new tensor with the truncated integer values of input. +$$out = trunc(x)$$ +)DOC"); + } +}; + +class TruncGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "TruncGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + framework::GradVarName("X"), "TruncGrad"); + + auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); + ctx->SetOutputDim(framework::GradVarName("X"), dout_dims); + } +}; + +template +class TruncGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("trunc_grad"); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker, + ops::TruncGradOpMaker, + ops::TruncGradOpMaker); + +REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp); + +REGISTER_OP_CPU_KERNEL(trunc, ops::TruncKernel, ops::TruncKernel, + ops::TruncKernel, ops::TruncKernel); + +REGISTER_OP_CPU_KERNEL(trunc_grad, ops::TruncGradKernel, + ops::TruncGradKernel, ops::TruncGradKernel, + ops::TruncGradKernel); diff --git a/paddle/fluid/operators/trunc_op.cu b/paddle/fluid/operators/trunc_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..a284e0ea6e393910c35f11a64039e6b58f2f67a2 --- /dev/null +++ b/paddle/fluid/operators/trunc_op.cu @@ -0,0 +1,115 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/trunc_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; + +template +class TruncFunctor { + public: + __device__ TruncFunctor(const T x) : x_(x) {} + __device__ T operator()() { return trunc(x_); } + + public: + const T x_; +}; + +template <> +class TruncFunctor { + public: + __device__ TruncFunctor(const int x) : x_(x) {} + __device__ int operator()() { return x_; } + + public: + const int x_; +}; + +template <> +class TruncFunctor { + public: + __device__ TruncFunctor(const int64_t x) : x_(x) {} + __device__ int64_t operator()() { return x_; } + + public: + const int64_t x_; +}; + +template +__global__ void Trunc(const T* x, T* out, int64_t N) { + CUDA_KERNEL_LOOP(index, N) { + TruncFunctor functor(x[index]); + out[index] = functor(); + } +} + +template +__global__ void TruncGrad(T* dx, int64_t N) { + CUDA_KERNEL_LOOP(index, N) { dx[index] = static_cast(0.0); } +} + +template +class TruncCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + + const auto* x_data = x->data(); + auto* out_data = out->mutable_data(context.GetPlace()); + + int64_t numel = x->numel(); + + int theads = PADDLE_CUDA_NUM_THREADS; + int blocks = (numel + theads - 1) / theads; + + Trunc<<>>(x_data, out_data, numel); + } +}; + +template +class TruncCUDAGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dout = context.Input(framework::GradVarName("Out")); + auto* dx = context.Output(framework::GradVarName("X")); + + const auto* dout_data = dout->data(); + auto* dx_data = dx->mutable_data(context.GetPlace()); + + int64_t numel = dout->numel(); + + int theads = PADDLE_CUDA_NUM_THREADS; + int blocks = (numel + theads - 1) / theads; + + TruncGrad<<>>(dx_data, numel); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(trunc, ops::TruncCUDAKernel, + ops::TruncCUDAKernel, ops::TruncCUDAKernel, + ops::TruncCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(trunc_grad, ops::TruncCUDAGradKernel, + ops::TruncCUDAGradKernel, + ops::TruncCUDAGradKernel, + ops::TruncCUDAGradKernel); diff --git a/paddle/fluid/operators/trunc_op.h b/paddle/fluid/operators/trunc_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0f788eae5249c57b92c7558451eca641a6840a41 --- /dev/null +++ b/paddle/fluid/operators/trunc_op.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class TruncKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* x = context.Input("X"); + Tensor* out = context.Output("Out"); + + size_t numel = x->numel(); + const T* x_data = x->data(); + T* out_data = out->mutable_data(context.GetPlace()); + + for (size_t i = 0; i < numel; i++) { + out_data[i] = trunc(x_data[i]); + } + } +}; + +template +class TruncGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dx = context.Output(framework::GradVarName("X")); + T* dx_data = dx->mutable_data(context.GetPlace()); + + int numel = dx->numel(); + memset(dx_data, 0.0, numel * sizeof(T)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 738de4e393d7782631fd1d2ca256e6afcb23fa9a..b5315a5d19ac7ef85f9c138218ba679082c39335 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -205,6 +205,7 @@ from .tensor.math import isnan # noqa: F401 from .tensor.math import prod # noqa: F401 from .tensor.math import broadcast_shape # noqa: F401 from .tensor.math import conj # noqa: F401 +from .tensor.math import trunc # noqa: F401 from .tensor.math import digamma # noqa: F401 from .tensor.math import neg # noqa: F401 from .tensor.math import lgamma # noqa: F401 @@ -490,6 +491,7 @@ __all__ = [ # noqa 'log10', 'concat', 'check_shape', + 'trunc' 'digamma', 'standard_normal' ] diff --git a/python/paddle/fluid/tests/unittests/test_trunc_op.py b/python/paddle/fluid/tests/unittests/test_trunc_op.py new file mode 100644 index 0000000000000000000000000000000000000000..51844071138c70f74834f829ddd329f978aa1bb1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_trunc_op.py @@ -0,0 +1,88 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + +paddle.enable_static() + + +class TestTruncOp(OpTest): + def setUp(self): + self.op_type = "trunc" + self.dtype = np.float64 + np.random.seed(2021) + self.inputs = {'X': np.random.random((20, 20)).astype(self.dtype)} + self.outputs = {'Out': (np.trunc(self.inputs['X']))} + + def init_dtype_type(self): + self.dtype = np.float64 + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5) + + +class TestFloatTruncOp(TestTruncOp): + def init_dtype_type(self): + self.dtype = np.float32 + + +class TestIntTruncOp(TestTruncOp): + def init_dtype_type(self): + self.dtype = np.int32 + + +class TestTruncAPI(unittest.TestCase): + def setUp(self): + self.shape = [20, 20] + self.x = np.random.random((20, 20)).astype(np.float32) + self.place = paddle.CPUPlace() + + def test_api_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', self.shape) + out = paddle.trunc(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x}, fetch_list=[out]) + out_ref = np.trunc(self.x) + for out in res: + self.assertEqual(np.allclose(out, out_ref, rtol=1e-08), True) + + def test_api_dygraph(self): + paddle.disable_static(self.place) + x_tensor = paddle.to_tensor(self.x) + out = paddle.trunc(x_tensor) + out_ref = np.trunc(self.x) + self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-08), True) + paddle.enable_static() + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', [20, 20], 'bool') + self.assertRaises(TypeError, paddle.trunc, x) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 8c83b1786b01e058b84211136750b271833dbb79..206aa62adfb779ce89598d5cb84a576bbbe12492 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -162,6 +162,7 @@ from .math import all # noqa: F401 from .math import any # noqa: F401 from .math import broadcast_shape # noqa: F401 from .math import conj # noqa: F401 +from .math import trunc # noqa: F401 from .math import digamma # noqa: F401 from .math import neg # noqa: F401 from .math import lgamma # noqa: F401 @@ -349,5 +350,6 @@ tensor_method_func = [ #noqa 'shape', 'real', 'imag', + 'trunc' 'digamma' ] diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index a9e24949aae2b18e4dcba30bc98613278ef52977..2ffb8d9302ccd87e59f7eac80f7e4c9a77e90462 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -857,6 +857,50 @@ def add_n(inputs, name=None): return out +def trunc(input, name=None): + ''' + This API is used to returns a new tensor with the truncated integer values of input. + + Args: + input (Tensor): The input tensor, it's data type should be int32, int64, float32, float64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: The output Tensor of trunc. + + Examples: + .. code-block:: python + + import paddle + + input = paddle.rand([2,2],'float32') + print(input) + # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [[0.02331470, 0.42374918], + # [0.79647720, 0.74970269]]) + + output = paddle.trunc(input) + print(output) + # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [[0., 0.], + # [0., 0.]])) + ''' + if in_dygraph_mode(): + return core.ops.trunc(input) + else: + inputs = {"X": input} + attrs = {} + + helper = LayerHelper("trunc", **locals()) + check_variable_and_dtype(input, 'X', ['int32', 'int64', 'float32', 'float64'], 'trunc') + out = helper.create_variable_for_type_inference(dtype=input.dtype) + + helper.append_op( + type="trunc", inputs=inputs, attrs=attrs, outputs={"Out": out}) + return out + + + def mm(input, mat2, name=None): """