Unverified commit 599a3264, authored by Zeng Jinle, committed by GitHub

Merge pull request #12971 from sneaxiy/unstack_op

Add unstack op
@@ -164,6 +164,7 @@ paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], vara
paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None))
paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
@@ -291,6 +291,7 @@ op_library(unsqueeze_op DEPS reshape_op)
op_library(squeeze_op DEPS reshape_op)
op_library(extract_rows_op DEPS memory)
op_library(flatten_op DEPS reshape_op)
op_library(unstack_op DEPS stack_op)
if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/unstack_op.h"
namespace plat = paddle::platform;
namespace ops = paddle::operators;
USE_OP(stack);
REGISTER_OPERATOR(unstack, ops::UnStackOp, ops::UnStackOpMaker,
ops::UnStackOpInferShape, ops::UnStackGradOpDescMaker);
REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp,
ops::UnStackOpGradInferShape);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class UnStackOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist.");

    int axis = ctx->Attrs().Get<int>("axis");
    int num = ctx->Attrs().Get<int>("num");
    auto x_dim = ctx->GetInputDim("X");
    int rank = x_dim.size();
    PADDLE_ENFORCE(axis >= -rank && axis < rank,
                   "Attr(axis) must be inside [-rank, rank), where rank = %d",
                   rank);
    if (axis < 0) axis += rank;

    PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast<size_t>(num),
                      "Number of Outputs(Y) must be equal to Attr(num).");
    if (x_dim[axis] > 0) {
      PADDLE_ENFORCE_EQ(num, x_dim[axis],
                        "Attr(num) must be equal to the size of Input(X) "
                        "along Attr(axis).");
    }

    auto vec = framework::vectorize2int(x_dim);
    vec.erase(vec.begin() + axis);
    ctx->SetOutputsDim("Y", std::vector<framework::DDim>(  // NOLINT
                                x_dim[axis], framework::make_ddim(vec)));
  }
};

class UnStackOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input of unstack op.");
    AddOutput("Y", "The output of unstack op.").AsDuplicable();
    AddAttr<int>("axis", "The axis along which Input(X) should be unstacked.")
        .SetDefault(0);
    AddAttr<int>("num", "The number of outputs(Y).").GreaterThan(0);
    AddComment(R"DOC(
UnStack Operator.

UnStack Input(X) into several tensors along Attr(axis).
)DOC");
  }
};

class UnStackOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
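    // Unstacking is the inverse of stacking, so the forward pass can be
    // implemented by delegating to the existing stack_grad kernel.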
    auto stack_grad_op = framework::OpRegistry::CreateOp(
        "stack_grad", {{framework::GradVarName("Y"), {Input("X")}}},
        {{framework::GradVarName("X"), Outputs("Y")}}, Attrs());
    stack_grad_op->Run(scope, place);
  }
};

class UnStackOpGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0,
                      "Number of Inputs(Y@Grad) must be larger than 0");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@Grad) must exist.");

    auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y"));
    for (size_t i = 1; i < input_dims.size(); ++i) {
      PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
                        "Dims of all Inputs(Y@Grad) must be the same");
    }

    int axis = ctx->Attrs().Get<int>("axis");
    int rank = input_dims[0].size();
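    // Output(X@Grad) has one more dimension than each Input(Y@Grad), so
    // Attr(axis) is validated against rank + 1 here.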
    PADDLE_ENFORCE(
        axis >= -(rank + 1) && axis < rank + 1,
        "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
    if (axis < 0) axis += (rank + 1);

    auto vec = framework::vectorize2int(input_dims[0]);
    vec.insert(vec.begin() + axis, input_dims.size());
    ctx->SetOutputDim(framework::GradVarName("X"), framework::make_ddim(vec));
  }
};

class UnStackGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("unstack_grad");
    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
    return op;
  }
};

class UnStackGradOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
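    // The gradient of unstack is stack: stack the incoming output
    // gradients along Attr(axis) to recover the gradient of Input(X).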
    auto stack_op = framework::OpRegistry::CreateOp(
        "stack", {{"X", Inputs(framework::GradVarName("Y"))}},
        {{"Y", {Output(framework::GradVarName("X"))}}}, Attrs());
    stack_op->Run(scope, place);
  }
};

}  // namespace operators
}  // namespace paddle
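The shape logic above removes dimension Attr(axis) from Input(X) and emits x_dim[axis] outputs, while the gradient re-inserts that dimension. A minimal numpy sketch of the same semantics, for orientation only (the np_unstack helper names are illustrative, not part of the operator):

import numpy as np

def np_unstack(x, axis=0):
    # Forward: split along `axis`, then drop that dimension from each piece.
    return [np.squeeze(s, axis=axis)
            for s in np.split(x, x.shape[axis], axis=axis)]

def np_unstack_grad(dys, axis=0):
    # Backward: stacking the output gradients along `axis` recovers dX.
    return np.stack(dys, axis=axis)

x = np.arange(24).reshape(2, 3, 4)
ys = np_unstack(x, axis=1)  # 3 arrays, each of shape (2, 4)
assert np.array_equal(np_unstack_grad(ys, axis=1), x)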
@@ -105,6 +105,7 @@ __all__ = [
'flatten',
'sequence_mask',
'stack',
'unstack',
]
@@ -5601,3 +5602,44 @@ def stack(x, axis=0):
        type='stack', inputs={'X': x}, outputs={'Y': out},
        attrs={'axis': axis})
    return out


def unstack(x, axis=0, num=None):
    """
    **UnStack Layer**

    This layer unstacks the input :code:`x` into several tensors along
    :code:`axis`. If :code:`axis` < 0, it is replaced with
    :code:`axis + rank(x)`. If :code:`num` is None, it is inferred from
    :code:`x.shape[axis]`; if :code:`x.shape[axis]` is unknown or
    non-positive, a :code:`ValueError` is raised.

    Args:
        x (Variable): Input variable.
        axis (int): The axis along which the input is unstacked.
        num (int|None): The number of output variables.

    Returns:
        list(Variable): The unstacked variables.
    """
    helper = LayerHelper('unstack', **locals())
    if num is None:
        if axis is None or x.shape[axis] <= 0:
            raise ValueError('unknown unstack number')
        else:
            num = x.shape[axis]

    outs = []
    for _ in range(num):
        outs.append(helper.create_tmp_variable(x.dtype))

    helper.append_op(
        type='unstack',
        inputs={'X': [x]},
        outputs={'Y': outs},
        attrs={'axis': axis,
               'num': num})
    return outs
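A minimal usage sketch of the new layer (the program setup is illustrative and assumes the fluid API of this branch):

import paddle.fluid as fluid

# x has shape [5, 6, 7]; unstacking along axis=1 yields 6 variables of
# shape [5, 7] each.
x = fluid.layers.data(name='x', shape=[5, 6, 7], append_batch_size=False)
ys = fluid.layers.unstack(x, axis=1)  # list of 6 Variables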
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from op_test import OpTest
import numpy as np
import unittest


class TestUnStackOpBase(OpTest):
    def initDefaultParameters(self):
        self.input_dim = (5, 6, 7)
        self.axis = 0
        self.dtype = 'float32'

    def initParameters(self):
        pass

    def get_y_names(self):
        y_names = []
        for i in range(self.input_dim[self.axis]):
            y_names.append('y{}'.format(i))
        return y_names
    def setUp(self):
        self.initDefaultParameters()
        self.initParameters()
        self.op_type = 'unstack'
        self.x = np.random.random(size=self.input_dim).astype(self.dtype)

        outs = np.split(self.x, self.input_dim[self.axis], self.axis)
        new_shape = list(self.input_dim)
        del new_shape[self.axis]
        y_names = self.get_y_names()
        tmp = []
        for i in range(self.input_dim[self.axis]):
            tmp.append((y_names[i], np.reshape(outs[i], new_shape)))

        self.inputs = {'X': self.x}
        self.outputs = {'Y': tmp}
        self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], self.get_y_names())

class TestUnStackOp3(TestUnStackOpBase):
    def initParameters(self):
        self.axis = -1


class TestUnStackOp4(TestUnStackOpBase):
    def initParameters(self):
        self.axis = -3


class TestUnStackOp5(TestUnStackOpBase):
    def initParameters(self):
        self.axis = 1


class TestUnStackOp6(TestUnStackOpBase):
    def initParameters(self):
        self.axis = 2


if __name__ == '__main__':
    unittest.main()