diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 2d4c9ef18f0d249d54ef7b8188a1e7a8876c9bcb..5e6d0b97f4327c95f6e50b6af548e6c2ad85c73a 100755
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -108,6 +108,7 @@ paddle.fluid.initializer.force_init_on_cpu (ArgSpec(args=[], varargs=None, keywo
 paddle.fluid.initializer.init_on_cpu (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'eaa04fd68661a3af59abd0e19b3b6eda'))
 paddle.fluid.initializer.NumpyArrayInitializer ('paddle.fluid.initializer.NumpyArrayInitializer', ('document', '064f134a27c16372967d450f499762ab'))
 paddle.fluid.initializer.NumpyArrayInitializer.__init__ (ArgSpec(args=['self', 'value'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
+paddle.fluid.input.one_hot (ArgSpec(args=['input', 'depth', 'allow_out_of_range'], varargs=None, keywords=None, defaults=(False,)), ('document', 'c79292312a35b99ff2801a274b666358'))
 paddle.fluid.layers.fc (ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, None)), ('document', '0dc8181f14a33f91fbae9385a9b3d9fd'))
 paddle.fluid.layers.center_loss (ArgSpec(args=['input', 'label', 'num_classes', 'alpha', 'param_attr', 'update_center'], varargs=None, keywords=None, defaults=(True,)), ('document', '7129819d94625c6104054e8187768589'))
 paddle.fluid.layers.embedding (ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32')), ('document', 'd8e405486a1e4e189b51d6ee28d67b1e'))
diff --git a/paddle/fluid/operators/one_hot_v2_op.cc b/paddle/fluid/operators/one_hot_v2_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7a75afca09cea13eb07749eb565ea880f8a5acf0
--- /dev/null
+++ b/paddle/fluid/operators/one_hot_v2_op.cc
@@ -0,0 +1,122 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/one_hot_v2_op.h"
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/framework.pb.h"
+
+namespace paddle {
+namespace operators {
+
+class OneHotV2Op : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      "Input(X) of OneHotV2Op should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Output(Out) of OneHotV2Op should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_GE(x_dims.size(), 1,
+                      "Rank of Input(X) should be at least 1.");
+
+    int depth = ctx->Attrs().Get<int>("depth");
+    if (ctx->HasInput("depth_tensor")) {
+      depth = -1;
+    }
+
+    auto out_dims_vec = framework::vectorize(x_dims);
+    out_dims_vec.push_back(depth);
+    auto out_dims = framework::make_ddim(out_dims_vec);
+    ctx->SetOutputDim("Out", out_dims);
+    ctx->ShareLoD("X", /* --> */ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::LoDTensor>("X")->type(), ctx.device_context());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    if (var_name == "depth_tensor") {
+      return expected_kernel_type;
+    }
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
+  }
+};
+
+class OneHotV2OpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(LoDTensor, LoDTensor<int>) Input variable with rank at least "
+             "1. Each value of X is an index to indicate the position.");
+    AddInput("depth_tensor", "(Tensor, Tensor<int>), Length of one-hot vector")
+        .AsDispensable();
+    AddOutput("Out",
+              "(Tensor, Tensor<float>) Output tensor with rank one greater "
+              "than X. The tensor consists of one-hot representations of "
+              "values in X.");
+
+    AddAttr<int>("depth",
+                 "A positive integer to specify the length of one-hot vector.")
+        .SetDefault(-1);
+    AddAttr<int>("dtype",
+                 "An integer to specify the data type of one-hot "
+                 "vector. The default value is FP32.")
+        .SetDefault(paddle::framework::proto::VarType::FP32);
+    AddAttr<bool>("allow_out_of_range",
+                  "If it is set true and the input data is out of range, "
+                  "the output tensor will be filled with zeros. The default "
+                  "value is false.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+One Hot Operator. This operator creates the one-hot representations for input
+index values. The following example will help to explain the function of this
+operator:
+
+X is a LoDTensor:
+  X.lod = [[0, 1, 4]]
+  X.shape = [4]
+  X.data = [1, 1, 3, 0]
+
+set depth = 4
+
+Out is a LoDTensor:
+  Out.lod = [[0, 1, 4]]
+  Out.shape = [4, 4]
+  Out.data = [[0., 1., 0., 0.],
+              [0., 1., 0., 0.],
+              [0., 0., 0., 1.],
+              [1., 0., 0., 0.]]
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(one_hot_v2, ops::OneHotV2Op, ops::OneHotV2OpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    one_hot_v2, ops::OneHotV2Kernel<paddle::platform::CPUDeviceContext, int>,
+    ops::OneHotV2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/one_hot_v2_op.cu b/paddle/fluid/operators/one_hot_v2_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..2366f1422244e34deaf9ba019eba1fde2620f7ad
--- /dev/null
+++ b/paddle/fluid/operators/one_hot_v2_op.cu
@@ -0,0 +1,99 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/one_hot_v2_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_info.h"
+
+namespace paddle {
+namespace operators {
+using platform::PADDLE_CUDA_NUM_THREADS;
+
+template <typename InT, typename OutT>
+__global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data,
+                                 const int64_t numel, const int depth) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) {
+    *(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
+  }
+}
+
+template <typename DeviceContext, typename InT>
+struct OneHotV2OpCUDAFunctor {
+  const framework::LoDTensor* in_;
+  framework::LoDTensor* out_;
+  const DeviceContext& ctx_;
+  int depth_;
+
+  OneHotV2OpCUDAFunctor(const framework::LoDTensor* in,
+                        framework::LoDTensor* out, int depth,
+                        const DeviceContext& ctx)
+      : in_(in), out_(out), depth_(depth), ctx_(ctx) {}
+
+  template <typename OutT>
+  void apply() const {
+    auto* p_in_data = in_->data<InT>();
+    auto numel = in_->numel();
+    auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
+    auto stream = ctx_.stream();
+    math::set_constant(ctx_, out_, 0.0);
+
+    FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
+                           PADDLE_CUDA_NUM_THREADS,
+                       PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
+        p_in_data, p_out_data, numel, depth_);
+  }
+};
+
+using LoDTensor = framework::LoDTensor;
+template <typename DeviceContext, typename T>
+class OneHotV2CUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<LoDTensor>("X");
+    auto* out = context.Output<LoDTensor>("Out");
+
+    int depth = -1;
+    if (context.HasInput("depth_tensor")) {
+      auto* depth_tensor = context.Input<framework::Tensor>("depth_tensor");
+      if (platform::is_gpu_place(depth_tensor->place())) {
+        framework::Tensor temp;
+        TensorCopySync(*depth_tensor, platform::CPUPlace(), &temp);
+        depth = *temp.data<int32_t>();
+      } else {
+        depth = *depth_tensor->data<int32_t>();
+      }
+
+      auto out_dims = out->dims();
+      out_dims[out_dims.size() - 1] = depth;
+      out->Resize(out_dims);
+    } else {
+      depth = context.Attr<int>("depth");
+    }
+    framework::VisitDataType(
+        static_cast<framework::proto::VarType::Type>(
+            context.Attr<int>("dtype")),
+        OneHotV2OpCUDAFunctor<platform::CUDADeviceContext, T>(
+            in, out, depth,
+            context.template device_context<platform::CUDADeviceContext>()));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    one_hot_v2,
+    ops::OneHotV2CUDAKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::OneHotV2CUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/one_hot_v2_op.h b/paddle/fluid/operators/one_hot_v2_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..7cfe2d61d17f32bce87e6428add0a9654dcba778
--- /dev/null
+++ b/paddle/fluid/operators/one_hot_v2_op.h
@@ -0,0 +1,94 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename InT>
+struct OneHotV2OpFunctor {
+  const framework::LoDTensor* in_;
+  framework::LoDTensor* out_;
+  int depth_;
+  const DeviceContext& ctx_;
+  bool allow_out_of_range_;
+
+  OneHotV2OpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out,
+                    int depth, const DeviceContext& ctx,
+                    bool allow_out_of_range = false)
+      : in_(in),
+        out_(out),
+        depth_(depth),
+        ctx_(ctx),
+        allow_out_of_range_(allow_out_of_range) {}
+
+  template <typename OutT>
+  void apply() const {
+    auto* p_in_data = in_->data<InT>();
+    auto numel = in_->numel();
+    auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
+    math::set_constant(ctx_, out_, 0.0);
+
+    if (allow_out_of_range_) {
+      for (int i = 0; i < numel; ++i) {
+        if (p_in_data[i] >= 0 && p_in_data[i] < depth_) {
+          *(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
+        }
+      }
+    } else {
+      for (int i = 0; i < numel; ++i) {
+        PADDLE_ENFORCE_GE(p_in_data[i], 0,
+                          "Illegal index value, should be at least 0.");
+        PADDLE_ENFORCE_LT(
+            p_in_data[i], depth_,
+            "Illegal index value, should be less than depth (%d).", depth_);
+        *(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
+      }
+    }
+  }
+};
+
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+template <typename DeviceContext, typename T>
+class OneHotV2Kernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<LoDTensor>("X");
+    auto* out = context.Output<LoDTensor>("Out");
+    int depth = context.Attr<int>("depth");
+    bool allow_out_of_range = context.Attr<bool>("allow_out_of_range");
+    if (context.HasInput("depth_tensor")) {
+      auto* depth_tensor = context.Input<Tensor>("depth_tensor");
+      auto* depth_data = depth_tensor->data<int32_t>();
+      depth = depth_data[0];
+      auto out_dims = out->dims();
+      out_dims[out_dims.size() - 1] = depth;
+      out->Resize(out_dims);
+    }
+
+    framework::VisitDataType(
+        static_cast<framework::proto::VarType::Type>(
+            context.Attr<int>("dtype")),
+        OneHotV2OpFunctor<DeviceContext, T>(
+            in, out, depth, context.template device_context<DeviceContext>(),
+            allow_out_of_range));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 8a62b9a4b227e2b56daf8423fc440e2ac85d57b6..f2d139b2e66b05587f1b597347f44deddcf241da 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -62,6 +62,7 @@ from . import average
 from . import metrics
 from . import transpiler
 from . import incubate
+from . import input
 from . import distribute_lookup_table
 from .param_attr import ParamAttr, WeightNormParamAttr
 from .data_feeder import DataFeeder
@@ -92,6 +93,7 @@ __all__ = framework.__all__ + executor.__all__ + \
     data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + [
     'io',
     'initializer',
+    'input',
     'layers',
     'contrib',
     'dygraph',
diff --git a/python/paddle/fluid/input.py b/python/paddle/fluid/input.py
new file mode 100644
index 0000000000000000000000000000000000000000..4169f646c0d9f0cdc132dbd31791b58549893fed
--- /dev/null
+++ b/python/paddle/fluid/input.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+from .framework import Variable, in_dygraph_mode
+from .layer_helper import LayerHelper
+
+__all__ = ['one_hot']
+
+
+def one_hot(input, depth, allow_out_of_range=False):
+    """
+    This layer creates the one-hot representations for input indices.
+
+    Args:
+        input(Variable): Input indices. Each index marks the position that
+            takes value 1.0 in the output, while all other positions take
+            value 0.
+        depth(scalar): An integer defining the depth of the one-hot dimension.
+        allow_out_of_range(bool): A bool value indicating whether the input
+            indices could be out of range [0, depth). When input indices are
+            out of range, an exception is raised if allow_out_of_range is
+            False, or a zero-filled representation is created if it is set
+            to True.
+
+    Returns:
+        Variable: The one-hot representations of input.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+            one_hot_label = fluid.input.one_hot(input=label, depth=10)
+    """
+    helper = LayerHelper("one_hot_v2", **locals())
+
+    one_hot_out = helper.create_variable_for_type_inference(dtype='float32')
+
+    if in_dygraph_mode():
+        inputs = {'X': input}
+        attrs = {'depth': depth}
+    else:
+        if not isinstance(depth, Variable):
+            # user attribute
+            inputs = {'X': input}
+            attrs = {'depth': depth}
+        else:
+            depth.stop_gradient = True
+            inputs = {'X': input, 'depth_tensor': depth}
+            attrs = {}
+    helper.append_op(
+        type="one_hot_v2",
+        inputs=inputs,
+        attrs=attrs,
+        outputs={'Out': one_hot_out},
+        stop_gradient=True)
+    return one_hot_out
diff --git a/python/paddle/fluid/tests/unittests/test_one_hot_v2_op.py b/python/paddle/fluid/tests/unittests/test_one_hot_v2_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..85069b0203984ed41fe92c651294922642adcc4a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_one_hot_v2_op.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import math
+from op_test import OpTest
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import paddle.fluid.framework as framework
+from paddle.fluid.framework import Program, program_guard
+
+
+class TestOneHotOp(OpTest):
+    def setUp(self):
+        self.op_type = 'one_hot_v2'
+        depth = 10
+        depth_np = np.array(10).astype('int32')
+        dimension = 12
+        x_lod = [[4, 1, 3, 3]]
+        x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
+        x = np.array(x).astype('int32').reshape([sum(x_lod[0])])
+
+        out = np.zeros(shape=(np.product(x.shape), depth)).astype('float32')
+
+        for i in range(np.product(x.shape)):
+            out[i, x[i]] = 1.0
+
+        self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
+        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
+        self.outputs = {'Out': (out, x_lod)}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestOneHotOp_attr(OpTest):
+    def setUp(self):
+        self.op_type = 'one_hot_v2'
+        depth = 10
+        dimension = 12
+        x_lod = [[4, 1, 3, 3]]
+        x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
+        x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
+
+        out = np.zeros(shape=(np.product(x.shape[:-1]), 1,
+                              depth)).astype('float32')
+
+        for i in range(np.product(x.shape)):
+            out[i, 0, x[i]] = 1.0
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32), 'depth': depth}
+        self.outputs = {'Out': (out, x_lod)}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestOneHotOp_default_dtype(OpTest):
+    def setUp(self):
+        self.op_type = 'one_hot_v2'
+        depth = 10
+        depth_np = np.array(10).astype('int32')
+        dimension = 12
+        x_lod = [[4, 1, 3, 3]]
+        x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
+        x = np.array(x).astype('int32').reshape([sum(x_lod[0])])
+
+        out = np.zeros(shape=(np.product(x.shape), depth)).astype('float32')
+
+        for i in range(np.product(x.shape)):
+            out[i, x[i]] = 1.0
+
+        self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
+        self.attrs = {}
+        self.outputs = {'Out': (out, x_lod)}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestOneHotOp_default_dtype_attr(OpTest):
+    def setUp(self):
+        self.op_type = 'one_hot_v2'
+        depth = 10
+        dimension = 12
+        x_lod = [[4, 1, 3, 3]]
+        x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
+        x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
+
+        out = np.zeros(shape=(np.product(x.shape[:-1]), 1,
+                              depth)).astype('float32')
+
+        for i in range(np.product(x.shape)):
+            out[i, 0, x[i]] = 1.0
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'depth': depth}
+        self.outputs = {'Out': (out, x_lod)}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestOneHotOp_out_of_range(OpTest):
+    def setUp(self):
+        self.op_type = 'one_hot_v2'
+        depth = 10
+        x_lod = [[4, 1, 3, 3]]
+        x = [np.random.choice([-1, depth]) for i in range(sum(x_lod[0]))]
+        x = np.array(x).astype('int32').reshape([sum(x_lod[0])])
+
+        out = np.zeros(shape=(np.product(x.shape), depth)).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'depth': depth, 'allow_out_of_range': True}
+        self.outputs = {'Out': (out, x_lod)}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestOneHotOp_exception(OpTest):
+    def setUp(self):
+        self.op_type = 'one_hot_v2'
+        self.depth = 10
+        self.place = core.CPUPlace()
+        self.dimension = 12
+        self.x = core.LoDTensor()
+        x_lod = [[4, 1, 3, 3]]
+        data = [np.random.randint(11, 20) for i in range(sum(x_lod[0]))]
+        data = np.array(data).astype('int').reshape([sum(x_lod[0]), 1])
+        self.x.set(data, self.place)
+        self.x.set_recursive_sequence_lengths(x_lod)
+
+    def test_check_output(self):
+        program = Program()
+        with program_guard(program):
+            x = fluid.layers.data(
+                name='x', shape=[self.dimension], dtype='float32', lod_level=1)
+            block = program.current_block()
+            one_hot_out = block.create_var(
+                name="one_hot_out",
+                type=core.VarDesc.VarType.LOD_TENSOR,
+                dtype='float32')
+            block.append_op(
+                type='one_hot',
+                inputs={'X': x},
+                attrs={'depth': self.depth},
+                outputs={'Out': one_hot_out})
+            exe = fluid.Executor(self.place)
+
+            def run():
+                exe.run(feed={'x': self.x},
+                        fetch_list=[one_hot_out],
+                        return_numpy=False)
+
+            self.assertRaises(core.EnforceNotMet, run)
+
+
+class TestOneHotOpApi(unittest.TestCase):
+    def test_api(self):
+        depth = 10
+        self._run(depth)
+
+    def test_api_with_depthTensor(self):
+        depth = fluid.layers.assign(input=np.array([10], dtype=np.int32))
+        self._run(depth)
+
+    def test_api_with_dygraph(self):
+        depth = 10
+        label = np.array([np.random.randint(0, depth - 1)
+                          for i in range(6)]).reshape([6, 1])
+        with fluid.dygraph.guard():
+            one_hot_label = fluid.input.one_hot(
+                input=fluid.dygraph.to_variable(label), depth=depth)
+
+    def _run(self, depth):
+        label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+        one_hot_label = fluid.input.one_hot(input=label, depth=depth)
+
+        place = fluid.CPUPlace()
+        label_data = np.array([np.random.randint(0, 10 - 1)
+                               for i in range(6)]).reshape([6, 1])
+
+        exe = fluid.Executor(place)
+        exe.run(fluid.default_startup_program())
+        ret = exe.run(feed={'label': label_data, },
+                      fetch_list=[one_hot_label],
+                      return_numpy=False)
+
+
+if __name__ == '__main__':
+    unittest.main()
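
For reference, below is a minimal NumPy sketch of the semantics the new one_hot_v2 kernel implements, including the allow_out_of_range behavior; the helper name one_hot_ref and the standalone script are illustrative only, not part of this patch.

    import numpy as np

    def one_hot_ref(x, depth, allow_out_of_range=False):
        # Append a new trailing axis of size `depth`; each input index selects
        # the position that is set to 1.0, all other positions stay 0.
        out = np.zeros(x.shape + (depth,), dtype='float32')
        for idx, v in np.ndenumerate(x):
            if 0 <= v < depth:
                out[idx + (int(v),)] = 1.0
            elif not allow_out_of_range:
                raise ValueError("index %d out of range [0, %d)" % (v, depth))
            # else: an out-of-range index leaves an all-zero row, as in the functor
        return out

    # Matches the example in the operator's DOC comment (depth = 4):
    # rows [0,1,0,0], [0,1,0,0], [0,0,0,1], [1,0,0,0]
    print(one_hot_ref(np.array([1, 1, 3, 0]), depth=4))

This mirrors the CPU functor in one_hot_v2_op.h: the output is zero-filled first (math::set_constant), then a 1.0 is scattered at each input index, with out-of-range indices either skipped or rejected depending on allow_out_of_range.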