未验证 提交 eb035f24 编写于 作者: myq406450149's avatar myq406450149 提交者: GitHub

add unbind op (#23359)

* add unbind op
unbind(tensor, dim=0):
说明:移除指定维后,返回一组数组,包含了沿着指定维切片后的各个切片。

tensor(Tensor) -- 输入Tensor
dim(int) -- 删除的维度

示例:
 Input = [[1,2],
           [3,4],
           [5,6]]
  axis = 0
  Output[0] = [1,2]
  Output[1] = [3,4]
  Output[2] = [5,6]
上级 fd9b7bdb
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unbind_op.h"
#include <string>
namespace paddle {
namespace operators {
using framework::Tensor;
class UnbindOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of UnbindOp is not found."));
PADDLE_ENFORCE_GE(
ctx->Outputs("Out").size(), 1UL,
platform::errors::NotFound("Outputs(Out) of UnbindOp is not found."));
auto in_dims = ctx->GetInputDim("X");
auto outs_names = ctx->Outputs("Out");
int axis = ctx->Attrs().Get<int>("axis");
const size_t outs_number = outs_names.size();
auto out_dims = UnbindOutsDims(in_dims, axis);
std::vector<framework::DDim> outs_dims(outs_number, out_dims);
ctx->SetOutputsDim("Out", outs_dims);
for (size_t i = 0; i < outs_number; ++i) {
ctx->ShareLoD("X", "Out", 0, i);
}
}
};
class UnbindOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor of the split operator.");
AddOutput("Out", "(Tensor) Output tensors of the unbind operator.")
.AsDuplicable();
AddComment(R"DOC(
Unbind operator
Remove a tensor dimension.
Example:
Input = [[1,2],
[3,4],
[5,6]]
axis = 0
Output[0] = [1,2]
Output[1] = [3,4]
Output[2] = [5,6]
)DOC");
AddAttr<int>("axis",
"(int, default 0) "
"dimension to remove.")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(unbind, ops::UnbindOp, ops::UnbindOpMaker,
ops::UnbindGradMaker<paddle::framework::OpDesc>,
ops::UnbindGradMaker<paddle::imperative::OpBase>);
namespace plat = paddle::platform;
REGISTER_OP_CPU_KERNEL(
unbind, ops::UnbindOpKernel<plat::CPUDeviceContext, double>,
ops::UnbindOpKernel<plat::CPUDeviceContext, float>,
ops::UnbindOpKernel<plat::CPUDeviceContext, int64_t>,
ops::UnbindOpKernel<plat::CPUDeviceContext, int>,
ops::UnbindOpKernel<plat::CPUDeviceContext, plat::float16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unbind_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
unbind, ops::UnbindOpKernel<plat::CUDADeviceContext, double>,
ops::UnbindOpKernel<plat::CUDADeviceContext, float>,
ops::UnbindOpKernel<plat::CUDADeviceContext, int64_t>,
ops::UnbindOpKernel<plat::CUDADeviceContext, int>,
ops::UnbindOpKernel<plat::CUDADeviceContext, plat::float16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <chrono> // NOLINT
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/utils.h"
namespace paddle {
namespace operators {
static inline framework::DDim UnbindOutsDims(const framework::DDim in_dims,
int axis) {
std::vector<int> out_dims;
axis = axis < 0 ? in_dims.size() + axis : axis;
for (int i = 0; i < in_dims.size(); i++) {
if (i != axis) out_dims.push_back(in_dims[i]);
}
return framework::make_ddim(out_dims);
}
template <typename DeviceContext, typename T>
class UnbindOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
int axis = ctx.Attr<int>("axis");
auto in_dims = in->dims();
auto place = ctx.GetPlace();
axis = axis < 0 ? in_dims.size() + axis : axis;
std::vector<const framework::Tensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->mutable_data<T>(ctx.GetPlace());
shape_refer.emplace_back(outs[j]);
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SplitFunctor<DeviceContext, T> functor;
functor(dev_ctx, *in, shape_refer, axis, &outs);
}
};
template <typename T>
class UnbindGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("stack");
op->SetInput("X", this->OutputGrad("Out"));
op->SetOutput("Y", this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard, core
class TestUnbindOp(OpTest):
def initParameters(self):
pass
def outReshape(self):
pass
def setAxis(self):
pass
def setUp(self):
self._set_op_type()
self.dtype = self.get_dtype()
self.axis = 0
self.num = 3
self.initParameters()
#x = np.random.random((3, 2, 2)).astype(self.dtype)
x = np.arange(12).reshape(3, 2, 2).astype(self.dtype)
self.out = np.split(x, self.num, self.axis)
self.outReshape()
self.inputs = {'X': x}
self.attrs = {'axis': self.axis}
self.setAxis()
self.outputs = {'Out': [('out%d' % i, self.out[i]) \
for i in range(len(self.out))]}
def get_dtype(self):
return "float64"
def _set_op_type(self):
self.op_type = "unbind"
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], ['out0', 'out1', 'out2'])
class TestUnbindOp1(TestUnbindOp):
def initParameters(self):
self.axis = 1
self.num = 2
def test_check_grad(self):
self.check_grad(['X'], ['out0', 'out1'])
def outReshape(self):
self.out[0] = self.out[0].reshape((3, 2))
self.out[1] = self.out[1].reshape((3, 2))
class TestUnbindOp2(TestUnbindOp):
def initParameters(self):
self.axis = 2
self.num = 2
def test_check_grad(self):
self.check_grad(['X'], ['out0', 'out1'])
def outReshape(self):
self.out[0] = self.out[0].reshape((3, 2))
self.out[1] = self.out[1].reshape((3, 2))
class TestUnbindOp3(TestUnbindOp):
def initParameters(self):
self.axis = 2
self.num = 2
def setAxis(self):
self.attrs = {'axis': -1}
def test_check_grad(self):
self.check_grad(['X'], ['out0', 'out1'])
def outReshape(self):
self.out[0] = self.out[0].reshape((3, 2))
self.out[1] = self.out[1].reshape((3, 2))
class TestUnbindOp4(TestUnbindOp):
def initParameters(self):
self.axis = 1
self.num = 2
def setAxis(self):
self.attrs = {'axis': -2}
def test_check_grad(self):
self.check_grad(['X'], ['out0', 'out1'])
def outReshape(self):
self.out[0] = self.out[0].reshape((3, 2))
self.out[1] = self.out[1].reshape((3, 2))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册