提交 84c12c6e 编写于 作者: Y Yang yaming 提交者: Abhinav Arora

Add one_hot operator. (#7819)

* Add one_hot operator.

* Add more unit tests.
上级 788f5c6d
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/operators/one_hot_op.h"
#include "paddle/framework/framework.pb.h"
namespace paddle {
namespace operators {
class OneHotOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of OneHotOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of OneHotOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"Rank of Input(X) should be at least 2.");
PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U,
"Last dimension of Input(X) should be 1.");
int depth = ctx->Attrs().Get<int>("depth");
PADDLE_ENFORCE_GT(depth, 0, "Should provide a positive depth (%d).", depth);
framework::DDim out_dims(x_dims);
out_dims[out_dims.size() - 1] = depth;
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /* --> */ "Out");
}
};
class OneHotOpMaker : public framework::OpProtoAndCheckerMaker {
public:
OneHotOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor, LoDTensor<int>) Input variable with rank at least 2. "
"The last dimension of X should be 1. Each value of X is an index "
"to indicate the position.");
AddOutput("Out",
"(Tensor, Tensor<float>) Output tensor with same rank as X. "
"The tensor consists of one-hot representations of values in X.");
AddAttr<int>("depth",
"A positive integer to specify the length of one-hot vector.");
AddAttr<int>("dtype",
"An integer to specify the data type of one-hot "
"vector. The default value is FP32.")
.SetDefault(paddle::framework::proto::DataType::FP32);
AddComment(R"DOC(
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
operator:
X is a LoDTensor:
X.lod = [[0, 1, 4]]
X.shape = [4, 1]
X.data = [[1], [1], [3], [0]]
set depth = 4
Out is a LoDTensor:
Out.lod = [[0, 1, 4]]
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.]]
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(one_hot, ops::OneHotOp, ops::OneHotOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
one_hot, ops::OneHotKernel<paddle::platform::CPUDeviceContext, int>,
ops::OneHotKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/operators/one_hot_op.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/gpu_info.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
template <typename InT, typename OutT>
__global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data,
const int64_t numel, const int depth) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) {
*(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
}
}
template <typename DeviceContext, typename InT>
struct OneHotOpCUDAFunctor {
const framework::LoDTensor* in_;
framework::LoDTensor* out_;
const DeviceContext& ctx_;
int depth_;
OneHotOpCUDAFunctor(const framework::LoDTensor* in, framework::LoDTensor* out,
int depth, const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
template <typename OutT>
void operator()() const {
auto* p_in_data = in_->data<InT>();
auto numel = in_->numel();
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
auto stream = ctx_.stream();
math::set_constant(ctx_, out_, 0.0);
FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
p_in_data, p_out_data, numel, depth_);
}
};
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class OneHotCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int depth = context.Attr<int>("depth");
framework::VisitDataType(
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")),
OneHotOpCUDAFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>()));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
one_hot, ops::OneHotCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::OneHotCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename InT>
struct OneHotOpFunctor {
const framework::LoDTensor* in_;
framework::LoDTensor* out_;
int depth_;
const DeviceContext& ctx_;
OneHotOpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out,
int depth, const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
template <typename OutT>
void operator()() const {
auto* p_in_data = in_->data<InT>();
auto numel = in_->numel();
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
math::set_constant(ctx_, out_, 0.0);
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(p_in_data[i], 0,
"Illegal index value, should be at least 0.");
PADDLE_ENFORCE_LT(p_in_data[i], depth_,
"Illegal index value, should be less than depth (%d).",
depth_);
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
};
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class OneHotKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int depth = context.Attr<int>("depth");
framework::VisitDataType(
static_cast<framework::proto::DataType>(context.Attr<int>("dtype")),
OneHotOpFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>()));
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import math
from op_test import OpTest
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
import paddle.v2.fluid.framework as framework
from paddle.v2.fluid.framework import Program, program_guard
class TestOneHotOp(OpTest):
def setUp(self):
self.op_type = 'one_hot'
depth = 10
dimension = 12
x_lod = [[0, 4, 5, 8, 11]]
x = [np.random.randint(0, depth - 1) for i in xrange(x_lod[0][-1])]
x = np.array(x).astype('int').reshape([x_lod[0][-1], 1])
out = np.zeros(shape=(np.product(x.shape[:-1]),
depth)).astype('float32')
for i in xrange(np.product(x.shape)):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth, 'dtype': int(core.DataType.FP32)}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output()
class TestOneHotOp_default_dtype(OpTest):
def setUp(self):
self.op_type = 'one_hot'
depth = 10
dimension = 12
x_lod = [[0, 4, 5, 8, 11]]
x = [np.random.randint(0, depth - 1) for i in xrange(x_lod[0][-1])]
x = np.array(x).astype('int').reshape([x_lod[0][-1], 1])
out = np.zeros(shape=(np.product(x.shape[:-1]),
depth)).astype('float32')
for i in xrange(np.product(x.shape)):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
self.check_output()
class TestOneHotOp_exception(OpTest):
def setUp(self):
self.op_type = 'one_hot'
self.depth = 10
self.place = core.CPUPlace()
self.dimension = 12
self.x = core.LoDTensor()
x_lod = [[0, 4, 5, 8, 11]]
data = [np.random.randint(11, 20) for i in xrange(x_lod[0][-1])]
data = np.array(data).astype('int').reshape([x_lod[0][-1], 1])
self.x.set(data, self.place)
self.x.set_lod(x_lod)
def test_check_output(self):
program = Program()
with program_guard(program):
x = fluid.layers.data(
name='x', shape=[self.dimension], dtype='float32', lod_level=1)
block = program.current_block()
one_hot_out = block.create_var(
name="one_hot_out",
type=core.VarDesc.VarType.LOD_TENSOR,
dtype='float32')
block.append_op(
type='one_hot',
inputs={'X': x},
attrs={'depth': self.depth},
outputs={'Out': one_hot_out})
exe = fluid.Executor(self.place)
def run():
exe.run(feed={'x': self.x},
fetch_list=[one_hot_out],
return_numpy=False)
self.assertRaises(core.EnforceNotMet, run)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册