Unverified commit cee2ccb0, authored by zhoushiyu, committed by GitHub

add shuffle batch op (#21674)

* add shuffle batch op, test=develop, test=document_preview

* fix size_t conflict and check_output test=develop, test=document_preview

* fix bug test=develop, test=document_preview

* add unittest of shuffle_batch layer test=develop, test=document_preview

* fix py coverage and op input type, test=develop, test=document_preview

* fix py coverage, test=develop

* fix en doc, test=develop

* move to contrib test=develop

* add unique_name test=develop

* invoke shuffle_batch in contrib.layers test=develop
Parent dca07583
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/shuffle_batch_op.h"
#include <memory>
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace paddle {
namespace operators {
class ShuffleBatchOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Seed"), true,
platform::errors::NotFound("Input(Seed) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("ShuffleIdx"), true,
platform::errors::NotFound("Output(ShuffleIdx) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("SeedOut"), true,
platform::errors::NotFound("Output(SeedOut) should not be null."));
ctx->ShareDim("X", "Out");
ctx->ShareLoD("X", "Out");
ctx->ShareDim("Seed", "SeedOut");
ctx->ShareLoD("Seed", "SeedOut");
ctx->SetOutputDim("ShuffleIdx", framework::make_ddim({-1}));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class ShuffleBatchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(LoDTensor) The input tensor of shuffle_batch op.");
AddInput("Seed", "(LoDTensor) The input seed tensor.");
AddAttr<int>(
"startup_seed",
"If input tensor 'Seed' is not initialized, the 'startup_seed' "
"will be used to replace it. The seed after shuffle batch will "
"be saved in 'SeedOut'. ")
.SetDefault(0);
AddOutput("Out", "(LoDTensor) The output tensor of shuffle_batch op.");
AddOutput("ShuffleIdx", "(Tensor) Record forword shuffle order");
AddOutput("SeedOut", "(LoDTensor) Saved new generated seed.");
AddComment(R"DOC(
Shuffle Batch Operator.
This operator is used to shuffle input $X$'s elements.
There is 2 input. The product of input dims (except last dim) numbers of elements will be shuffled. $Seed$ is tensor of seed.
There are 3 outputs. $Out$ is shuffled tensor of input. $ShuffleIdx$ is the tensor used to record shuffle order. $SeedOut$ is same tensor of $Seed$.
)DOC");
}
};
class ShuffleBatchOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("ShuffleIdx"), true,
platform::errors::NotFound("Input(ShuffleIdx) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound("Grad Input(Out) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::NotFound("Grad Output(X) should not be null"));
ctx->ShareDim(framework::GradVarName("Out"), framework::GradVarName("X"));
ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
template <typename T>
class ShuffleBatchGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
op->SetType("shuffle_batch_grad");
op->SetInput("ShuffleIdx", this->Output("ShuffleIdx"));
op->SetAttrMap(this->Attrs());
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(shuffle_batch, ops::ShuffleBatchOp, ops::ShuffleBatchOpMaker,
ops::ShuffleBatchGradOpMaker<paddle::framework::OpDesc>,
ops::ShuffleBatchGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(shuffle_batch_grad, ops::ShuffleBatchOpGrad);
REGISTER_OP_CPU_KERNEL(shuffle_batch, ops::ShuffleBatchKernel<float>,
ops::ShuffleBatchKernel<double>,
ops::ShuffleBatchKernel<int32_t>,
ops::ShuffleBatchKernel<int64_t>);
REGISTER_OP_CPU_KERNEL(shuffle_batch_grad, ops::ShuffleBatchGradKernel<float>,
ops::ShuffleBatchGradKernel<double>,
ops::ShuffleBatchGradKernel<int32_t>,
ops::ShuffleBatchGradKernel<int64_t>);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <cstring>
#include <ctime>
#include <random>
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
#if defined(PADDLE_WITH_CUDA)
template <typename T>
using Vector = framework::Vector<T>;
#else
template <typename T>
using Vector = framework::CPUVector<T>;
#endif
template <typename T>
class ShuffleBatchKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *x = context.Input<LoDTensor>("X");
auto *seed = context.Input<LoDTensor>("Seed");
auto *out = context.Output<LoDTensor>("Out");
auto *shuffleidx = context.Output<LoDTensor>("ShuffleIdx");
auto *seed_out = context.Output<LoDTensor>("SeedOut");
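// View X as a matrix of rows: the last dim is the row width, all leading dims form the row count.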
auto x_embed_size = x->dims()[x->dims().size() - 1];
auto elem_size = 1;
for (auto i = 0; i < x->dims().size() - 1; i++) elem_size *= x->dims()[i];
std::vector<int64_t> idx_vec; // record shuffled order
idx_vec.reserve(elem_size);
for (auto i = 0; i < elem_size; i++) {
idx_vec.push_back(i);
}
int64_t seed_int = 0;
if (seed->IsInitialized()) {
seed_int = *seed->data<int64_t>();
} else {
seed_int = context.Attr<int>("startup_seed");
}
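// Shuffle the row indices with an engine seeded by Seed (or by startup_seed when Seed is uninitialized).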
std::default_random_engine engine;
engine.seed(seed_int);
std::shuffle(idx_vec.begin(), idx_vec.end(), engine);
// ShuffleIdx record shuffle order
shuffleidx->Resize(framework::make_ddim({(int64_t)idx_vec.size()}));
auto *shuffleidx_data =
shuffleidx->mutable_data<int64_t>(context.GetPlace());
for (size_t i = 0; i < idx_vec.size(); i++) {
shuffleidx_data[i] = idx_vec[i];
}
// Scatter rows according to idx_vec: row i of X is copied to row idx_vec[i] of Out
auto *x_data = x->data<T>();
auto *out_data = out->mutable_data<T>(context.GetPlace());
for (auto i = 0; i < elem_size; i++) {
memcpy(out_data + idx_vec[i] * x_embed_size, x_data + i * x_embed_size,
x_embed_size * sizeof(T));
}
// set new seed
*seed_out->mutable_data<int64_t>(framework::make_ddim({1}),
context.GetPlace()) = engine();
}
};
template <typename T>
class ShuffleBatchGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *out_grad = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *shuffleidx = context.Input<LoDTensor>("ShuffleIdx");
auto *x_grad = context.Output<LoDTensor>(framework::GradVarName("X"));
auto embed_size = out_grad->dims()[out_grad->dims().size() - 1];
auto elem_size = 1;
for (auto i = 0; i < out_grad->dims().size() - 1; i++)
elem_size *= out_grad->dims()[i];
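// Invert the forward permutation: idx_vec_grad[j] is the X row that was copied to Out row j.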
std::vector<int> idx_vec_grad(elem_size);
auto *shuffleidx_data = shuffleidx->data<int64_t>();
for (size_t i = 0; i < idx_vec_grad.size(); i++) {
idx_vec_grad[shuffleidx_data[i]] = i;
}
// Scatter rows back: the gradient of Out row i is copied to X-grad row idx_vec_grad[i]
auto *out_grad_data = out_grad->data<T>();
auto *x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
for (auto i = 0; i < elem_size; i++) {
memcpy(x_grad_data + idx_vec_grad[i] * embed_size,
out_grad_data + i * embed_size, embed_size * sizeof(T));
}
}
};
} // namespace operators
} // namespace paddle
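For reference, here is a minimal NumPy sketch of what the two kernels above compute. The shapes, seed, and upstream gradient are illustrative assumptions, and NumPy's generator stands in for std::default_random_engine, so the concrete permutation will differ from the C++ kernels.

import numpy as np

x = np.arange(8, dtype=np.float32).reshape(4, 2)  # elem_size = 4 rows, last dim = 2

rng = np.random.default_rng(1)       # stands in for the seeded engine
idx_vec = np.arange(x.shape[0])
rng.shuffle(idx_vec)                 # ShuffleIdx: row i of X goes to row idx_vec[i] of Out

out = np.empty_like(x)
out[idx_vec] = x                     # forward kernel: scatter rows

# grad kernel: invert the permutation and scatter gradients back to their source rows
idx_vec_grad = np.empty_like(idx_vec)
idx_vec_grad[idx_vec] = np.arange(len(idx_vec))
out_grad = 2.0 * out                 # pretend upstream gradient
x_grad = np.empty_like(x)
x_grad[idx_vec_grad] = out_grad      # each X row receives the grad of its Out row
assert np.allclose(x_grad, 2.0 * x)  # gradients land back in the original row order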
@@ -23,6 +23,7 @@ import os
import inspect
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers import utils
from ... import unique_name
__all__ = [
'fused_elemwise_activation',
@@ -33,6 +34,7 @@ __all__ = [
'fused_embedding_seq_pool',
'multiclass_nms2',
'search_pyramid_hash',
'shuffle_batch',
]
@@ -722,3 +724,67 @@ def search_pyramid_hash(input,
})
return res
def shuffle_batch(x, seed=None):
"""
This layer shuffles the input tensor :attr:`x`. Normally, :attr:`x` is a 2-D LoDTensor.
:attr:`x` is a LoDTensor to be shuffled with shape :math:`[N_1, N_2, ..., N_k, D]`. Note that the last dim of the input will not be shuffled.
:math:`N_1 * N_2 * ... * N_k` elements of length :math:`D` will be shuffled randomly.
For Example:
.. code-block:: text
Input:
x.data = [[1, 2], [3, 4], [5, 6], [7, 8]]
x.dims = [4, 2]
Attrs:
seed = 2019
Output:
Out.data = [[7, 8], [1, 2], [3, 4], [5, 6]]
Out.dims = [4, 2]
Args:
x (Variable): The input variable. The input variable is an N-D LoDTensor with type int, float32 or float64.
seed (None|int|Variable): The startup seed. If set, it is used as the startup seed of the shuffle engine.
If not set (default), the startup seed of the shuffle engine is generated randomly.
Returns:
Variable: The shuffled LoDTensor with the same shape and LoD as the input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name="x", shape=[-1, 4])
out = fluid.contrib.layers.shuffle_batch(x)
"""
helper = LayerHelper('shuffle_batch', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
shuffle_idx = helper.create_variable_for_type_inference(dtype=np.int64)
if seed is None and helper.main_program.random_seed != 0:
seed = helper.main_program.random_seed
if seed is None:
seed = np.random.randint(-65536, 65535)
op_attrs = {}
if isinstance(seed, int):
op_attrs["startup_seed"] = seed
seed = helper.create_variable(
name=unique_name.generate("shuffle_batch_seed"),
dtype="int64",
persistable=True)
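# The seed variable is intentionally left uninitialized here: on the first run the op falls back
# to the 'startup_seed' attribute, and it keeps writing the updated seed back into this same
# variable through 'SeedOut'.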
helper.append_op(
type='shuffle_batch',
inputs={'X': x,
'Seed': seed},
outputs={'Out': out,
'ShuffleIdx': shuffle_idx,
'SeedOut': seed},
attrs=op_attrs)
return out
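Beyond the docstring snippet, a hedged end-to-end sketch of how the new layer could be exercised under the static-graph API; the feed data, seed value, and CPUPlace choice are assumptions for illustration, not part of this commit.

import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name="x", shape=[-1, 4], dtype="float32", append_batch_size=False)
out = fluid.contrib.layers.shuffle_batch(x, seed=2019)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

feed_x = np.arange(12, dtype=np.float32).reshape(3, 4)  # 3 rows of width 4
shuffled, = exe.run(fluid.default_main_program(),
                    feed={"x": feed_x},
                    fetch_list=[out])
# 'shuffled' contains the same rows as feed_x, in a permuted order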
@@ -2828,6 +2828,18 @@ class TestBook(LayerTest):
name='Filter_tag')
out1, out2 = layers.filter_by_instag(x1, x2, x3, is_lod=True)
def test_shuffle_batch(self):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
x = layers.data(
name='X', shape=[4, 50], dtype='float32', lod_level=0)
out1 = fluid.contrib.layers.shuffle_batch(x)
default_main_program().random_seed = 1000
out2 = fluid.contrib.layers.shuffle_batch(x)
self.assertIsNotNone(out1)
self.assertIsNotNone(out2)
return (out1)
def test_roi_pool(self):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is unit test of Test shuffle_batch Op."""
from __future__ import print_function, division
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from op_test import OpTest
import random
class TestShuffleBatchOp(OpTest):
def setUp(self):
self.op_type = 'shuffle_batch'
self.dtype = np.float64
x = np.array(
[np.arange(100), np.arange(100)]).astype(self.dtype).reshape(
[2, 100])
out = np.array(
[np.arange(100), np.arange(100)]).astype(self.dtype).reshape(
[2, 100])
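# Both rows of x are identical, so any permutation of the batch reproduces the input;
# a single candidate result is enough for the check below.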
self.possible_res = [
np.array([np.arange(100), np.arange(100)]).astype(self.dtype),
]
self.inputs = {'X': x, 'Seed': np.array([1]).astype('int64')}
self.outputs = {
'Out': out,
'ShuffleIdx': np.array([1, 0]).astype('int64'),
'SeedOut': np.array([1]).astype('int64')
}
self.attrs = {'startup_seed': 1}
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
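# 'outs' holds every fetched output; pick the one whose shape matches Out,
# then accept it if it equals any admissible permutation.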
for elem in outs:
if elem.shape == self.outputs['Out'].shape:
out = elem
break
is_equal = [np.all(out == res) for res in self.possible_res]
self.assertIn(True, is_equal)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()