From 84c12c6edc2bc0b6b410c4385e928c8d061ea18e Mon Sep 17 00:00:00 2001 From: Yang yaming Date: Fri, 26 Jan 2018 15:35:23 +0800 Subject: [PATCH] Add one_hot operator. (#7819) * Add one_hot operator. * Add more unit tests. --- paddle/operators/one_hot_op.cc | 95 +++++++++++++++ paddle/operators/one_hot_op.cu | 80 +++++++++++++ paddle/operators/one_hot_op.h | 68 +++++++++++ .../paddle/v2/fluid/tests/test_one_hot_op.py | 110 ++++++++++++++++++ 4 files changed, 353 insertions(+) create mode 100644 paddle/operators/one_hot_op.cc create mode 100644 paddle/operators/one_hot_op.cu create mode 100644 paddle/operators/one_hot_op.h create mode 100644 python/paddle/v2/fluid/tests/test_one_hot_op.py diff --git a/paddle/operators/one_hot_op.cc b/paddle/operators/one_hot_op.cc new file mode 100644 index 00000000000..e78b7468de4 --- /dev/null +++ b/paddle/operators/one_hot_op.cc @@ -0,0 +1,95 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/operators/one_hot_op.h" +#include "paddle/framework/framework.pb.h" + +namespace paddle { +namespace operators { + +class OneHotOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of OneHotOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of OneHotOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + "Rank of Input(X) should be at least 2."); + PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U, + "Last dimension of Input(X) should be 1."); + + int depth = ctx->Attrs().Get("depth"); + + PADDLE_ENFORCE_GT(depth, 0, "Should provide a positive depth (%d).", depth); + + framework::DDim out_dims(x_dims); + out_dims[out_dims.size() - 1] = depth; + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", /* --> */ "Out"); + } +}; + +class OneHotOpMaker : public framework::OpProtoAndCheckerMaker { + public: + OneHotOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor, LoDTensor) Input variable with rank at least 2. " + "The last dimension of X should be 1. Each value of X is an index " + "to indicate the position."); + AddOutput("Out", + "(Tensor, Tensor) Output tensor with same rank as X. " + "The tensor consists of one-hot representations of values in X."); + AddAttr("depth", + "A positive integer to specify the length of one-hot vector."); + AddAttr("dtype", + "An integer to specify the data type of one-hot " + "vector. The default value is FP32.") + .SetDefault(paddle::framework::proto::DataType::FP32); + AddComment(R"DOC( +One Hot Operator. This operator creates the one-hot representations for input +index values. The following example will help to explain the function of this +operator: + +X is a LoDTensor: + X.lod = [[0, 1, 4]] + X.shape = [4, 1] + X.data = [[1], [1], [3], [0]] + +set depth = 4 + +Out is a LoDTensor: + Out.lod = [[0, 1, 4]] + Out.shape = [4, 4] + Out.data = [[0., 1., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 0., 1.], + [1., 0., 0., 0.]] +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(one_hot, ops::OneHotOp, ops::OneHotOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + one_hot, ops::OneHotKernel, + ops::OneHotKernel); diff --git a/paddle/operators/one_hot_op.cu b/paddle/operators/one_hot_op.cu new file mode 100644 index 00000000000..16f6d9433ea --- /dev/null +++ b/paddle/operators/one_hot_op.cu @@ -0,0 +1,80 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/operators/one_hot_op.h" +#include "paddle/platform/cuda_helper.h" +#include "paddle/platform/gpu_info.h" + +namespace paddle { +namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data, + const int64_t numel, const int depth) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel) { + *(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0; + } +} + +template +struct OneHotOpCUDAFunctor { + const framework::LoDTensor* in_; + framework::LoDTensor* out_; + const DeviceContext& ctx_; + int depth_; + + OneHotOpCUDAFunctor(const framework::LoDTensor* in, framework::LoDTensor* out, + int depth, const DeviceContext& ctx) + : in_(in), out_(out), depth_(depth), ctx_(ctx) {} + + template + void operator()() const { + auto* p_in_data = in_->data(); + auto numel = in_->numel(); + auto* p_out_data = out_->mutable_data(ctx_.GetPlace()); + auto stream = ctx_.stream(); + math::set_constant(ctx_, out_, 0.0); + + FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / + PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + p_in_data, p_out_data, numel, depth_); + } +}; + +using LoDTensor = framework::LoDTensor; +template +class OneHotCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + int depth = context.Attr("depth"); + + framework::VisitDataType( + static_cast(context.Attr("dtype")), + OneHotOpCUDAFunctor( + in, out, depth, context.template device_context())); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + one_hot, ops::OneHotCUDAKernel, + ops::OneHotCUDAKernel); diff --git a/paddle/operators/one_hot_op.h b/paddle/operators/one_hot_op.h new file mode 100644 index 00000000000..12031ede2c3 --- /dev/null +++ b/paddle/operators/one_hot_op.h @@ -0,0 +1,68 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +struct OneHotOpFunctor { + const framework::LoDTensor* in_; + framework::LoDTensor* out_; + int depth_; + const DeviceContext& ctx_; + + OneHotOpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out, + int depth, const DeviceContext& ctx) + : in_(in), out_(out), depth_(depth), ctx_(ctx) {} + + template + void operator()() const { + auto* p_in_data = in_->data(); + auto numel = in_->numel(); + auto* p_out_data = out_->mutable_data(ctx_.GetPlace()); + math::set_constant(ctx_, out_, 0.0); + + for (int i = 0; i < numel; ++i) { + PADDLE_ENFORCE_GE(p_in_data[i], 0, + "Illegal index value, should be at least 0."); + PADDLE_ENFORCE_LT(p_in_data[i], depth_, + "Illegal index value, should be less than depth (%d).", + depth_); + *(p_out_data + i * depth_ + p_in_data[i]) = 1.0; + } + } +}; + +using LoDTensor = framework::LoDTensor; +template +class OneHotKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + int depth = context.Attr("depth"); + + framework::VisitDataType( + static_cast(context.Attr("dtype")), + OneHotOpFunctor( + in, out, depth, context.template device_context())); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_one_hot_op.py b/python/paddle/v2/fluid/tests/test_one_hot_op.py new file mode 100644 index 00000000000..e51ea27d14d --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_one_hot_op.py @@ -0,0 +1,110 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import math +from op_test import OpTest +import paddle.v2.fluid as fluid +import paddle.v2.fluid.core as core +import paddle.v2.fluid.framework as framework +from paddle.v2.fluid.framework import Program, program_guard + + +class TestOneHotOp(OpTest): + def setUp(self): + self.op_type = 'one_hot' + depth = 10 + dimension = 12 + x_lod = [[0, 4, 5, 8, 11]] + x = [np.random.randint(0, depth - 1) for i in xrange(x_lod[0][-1])] + x = np.array(x).astype('int').reshape([x_lod[0][-1], 1]) + + out = np.zeros(shape=(np.product(x.shape[:-1]), + depth)).astype('float32') + + for i in xrange(np.product(x.shape)): + out[i, x[i]] = 1.0 + + self.inputs = {'X': (x, x_lod)} + self.attrs = {'depth': depth, 'dtype': int(core.DataType.FP32)} + self.outputs = {'Out': (out, x_lod)} + + def test_check_output(self): + self.check_output() + + +class TestOneHotOp_default_dtype(OpTest): + def setUp(self): + self.op_type = 'one_hot' + depth = 10 + dimension = 12 + x_lod = [[0, 4, 5, 8, 11]] + x = [np.random.randint(0, depth - 1) for i in xrange(x_lod[0][-1])] + x = np.array(x).astype('int').reshape([x_lod[0][-1], 1]) + + out = np.zeros(shape=(np.product(x.shape[:-1]), + depth)).astype('float32') + + for i in xrange(np.product(x.shape)): + out[i, x[i]] = 1.0 + + self.inputs = {'X': (x, x_lod)} + self.attrs = {'depth': depth} + self.outputs = {'Out': (out, x_lod)} + + def test_check_output(self): + self.check_output() + + +class TestOneHotOp_exception(OpTest): + def setUp(self): + self.op_type = 'one_hot' + self.depth = 10 + self.place = core.CPUPlace() + self.dimension = 12 + self.x = core.LoDTensor() + x_lod = [[0, 4, 5, 8, 11]] + data = [np.random.randint(11, 20) for i in xrange(x_lod[0][-1])] + data = np.array(data).astype('int').reshape([x_lod[0][-1], 1]) + self.x.set(data, self.place) + self.x.set_lod(x_lod) + + def test_check_output(self): + program = Program() + with program_guard(program): + x = fluid.layers.data( + name='x', shape=[self.dimension], dtype='float32', lod_level=1) + block = program.current_block() + one_hot_out = block.create_var( + name="one_hot_out", + type=core.VarDesc.VarType.LOD_TENSOR, + dtype='float32') + block.append_op( + type='one_hot', + inputs={'X': x}, + attrs={'depth': self.depth}, + outputs={'Out': one_hot_out}) + exe = fluid.Executor(self.place) + + def run(): + exe.run(feed={'x': self.x}, + fetch_list=[one_hot_out], + return_numpy=False) + + self.assertRaises(core.EnforceNotMet, run) + + +if __name__ == '__main__': + unittest.main() -- GitLab