提交 688ed601 编写于 作者: L li099 提交者: whs

Add lod tensor array to tensor op (#13990)

* add lod tensor array concat

* add lod tensor array concat

* test=develop

* add lod tensor array concat
test=develop

* Fix API.spec
test=develop

* add lod tensor array concat
test=develop

* revise some bug of lod tensor array concat
test=develop

* add unittest for tensor array concat
test=develop

* change to tensor array to tensor
test=develop

* revise bug
test=develop

* revise a bug
test=develop

* revise a bug
test=develop

* revise a bug of python3
test=develop
上级 6c6e6385
......@@ -201,6 +201,7 @@ paddle.fluid.layers.create_tensor ArgSpec(args=['dtype', 'name', 'persistable'],
paddle.fluid.layers.create_parameter ArgSpec(args=['shape', 'dtype', 'name', 'attr', 'is_bias', 'default_initializer'], varargs=None, keywords=None, defaults=(None, None, False, None))
paddle.fluid.layers.create_global_var ArgSpec(args=['shape', 'value', 'dtype', 'persistable', 'force_cpu', 'name'], varargs=None, keywords=None, defaults=(False, False, None))
paddle.fluid.layers.cast ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.tensor_array_to_tensor ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.concat ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.sums ArgSpec(args=['input', 'out'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.assign ArgSpec(args=['input', 'output'], varargs=None, keywords=None, defaults=(None,))
......
......@@ -317,6 +317,7 @@ op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor)
op_library(tensor_array_to_tensor_op DEPS concat_op)
op_library(concat_op DEPS concat_and_split)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace operators {
using framework::Tensor;
// Exposes every element of the LoDTensorArray held by the scope variable
// `lod_tensor_array_name` as a separate scope variable.
//
// For element i a variable named `base_name + std::to_string(i)` is obtained
// from the scope via Scope::Var, its LoDTensor is pointed at the element
// through ShareDataWith, and the name is appended to `res_names` (in element
// order).
void LodTensorArray2LodTensorVector(const framework::Scope &scope,
                                    const std::string &base_name,
                                    const std::string &lod_tensor_array_name,
                                    std::vector<std::string> *res_names) {
  const auto &array =
      scope.FindVar(lod_tensor_array_name)->Get<framework::LoDTensorArray>();
  // The scope arrives as const, but per-item variables must be added to it.
  auto &mutable_scope = const_cast<framework::Scope &>(scope);

  size_t position = 0;
  for (const auto &item : array) {
    const std::string item_name = base_name + std::to_string(position);
    ++position;

    framework::LoDTensor *item_tensor =
        mutable_scope.Var(item_name)->GetMutable<framework::LoDTensor>();
    item_tensor->ShareDataWith(item);
    res_names->push_back(item_name);
  }
}
// Prepares one scope variable per element of the LoDTensorArray stored in
// `lod_tensor_array_name`, giving each the same dims as its array element.
//
// Variable i is named `base_name + std::to_string(i)`; only Resize is called
// on its LoDTensor (no data is shared or copied here). The names are appended
// to `res_names` in element order.
void LodTensorVectorResizeFromLodTensorArray(
    const framework::Scope &scope, const std::string &base_name,
    const std::string &lod_tensor_array_name,
    std::vector<std::string> *res_names) {
  const auto &array =
      scope.FindVar(lod_tensor_array_name)->Get<framework::LoDTensorArray>();
  // New variables are created in the scope even though it is passed as const.
  auto &mutable_scope = const_cast<framework::Scope &>(scope);

  for (size_t pos = 0; pos < array.size(); ++pos) {
    const std::string slot_name = base_name + std::to_string(pos);
    framework::LoDTensor *slot =
        mutable_scope.Var(slot_name)->GetMutable<framework::LoDTensor>();
    slot->Resize(array[pos].dims());
    res_names->push_back(slot_name);
  }
}
// Fills the LoDTensorArray stored in `output_lod_tensor_array_name` with one
// entry per element of the array stored in `input_lod_tensor_array_name`.
//
// Entry i is a copy of the LoDTensor held by the scope variable
// `output_lod_tensor_array_name + std::to_string(i)` (created if missing).
//
// The output array is cleared first: the variable may already hold entries
// from an earlier run of the same operator, and appending to them would make
// the array grow on every run and leave stale tensors behind the fresh ones.
void LodTensorArrayCreateFromLodTensorArray(
    const framework::Scope &scope,
    const std::string &input_lod_tensor_array_name,
    const std::string &output_lod_tensor_array_name) {
  auto &inx = scope.FindVar(input_lod_tensor_array_name)
                  ->Get<framework::LoDTensorArray>();
  auto &grad_inx = *scope.FindVar(output_lod_tensor_array_name)
                        ->GetMutable<framework::LoDTensorArray>();

  // Read the element count before touching the output array.
  const size_t item_count = inx.size();
  grad_inx.clear();
  grad_inx.reserve(item_count);

  for (size_t i = 0; i < item_count; i++) {
    std::string var_name = output_lod_tensor_array_name + std::to_string(i);
    framework::Variable *g_feed_value =
        const_cast<framework::Scope &>(scope).Var(var_name);
    auto &feed_input =
        *(g_feed_value->GetMutable<paddle::framework::LoDTensor>());
    grad_inx.push_back(feed_input);
  }
}
// Forward operator of tensor_array_to_tensor.
//
// Concatenates all tensors of the input LoDTensorArray "X" along attribute
// "axis" into output "Out" by delegating to the `concat` operator, and writes
// the size of every array item along "axis" into output "OutIndex".
class LoDTensorArray2TensorOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto axis = Attr<int>("axis");

    // Attributes forwarded unchanged to the inner concat operator.
    framework::AttributeMap attrs;
    attrs["axis"] = axis;

    auto &inx = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
    auto &out =
        *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
    auto &out_inx =
        *scope.FindVar(Output("OutIndex"))->GetMutable<framework::LoDTensor>();

    const size_t n = inx.size();
    PADDLE_ENFORCE_GT(n, 0, "Input tensorarray size should > 0.");

    std::string base_name = Inputs("X")[0];
    std::vector<std::string> names;

    // OutIndex holds one int per array item. It is filled in place (on CPU)
    // instead of going through a scope variable with a fixed name: such a
    // scratch variable would be shared by every instance of this operator in
    // the scope, so their OutIndex outputs would alias the same buffer.
    auto out_inx_dim = out_inx.dims();
    out_inx_dim[0] = n;
    out_inx.Resize(out_inx_dim);
    int *out_inx_data = out_inx.mutable_data<int>(platform::CPUPlace());

    // Out has the dims of the first item, except along `axis`, where it is
    // the sum of all items' sizes.
    auto out_dims = inx[0].dims();
    size_t out_dim_sum = 0;
    for (size_t index = 0; index < n; index++) {
      auto inx_dims = inx[index].dims();
      out_dim_sum += inx_dims[axis];
      out_inx_data[index] = inx_dims[axis];
    }
    out_dims[axis] = out_dim_sum;
    out.Resize(out_dims);

    // Expose every array item as its own scope variable so that it can be
    // passed to concat as a regular tensor input.
    LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);

    // Invoke the concat op on the per-item variables.
    auto concat_op = framework::OpRegistry::CreateOp(
        "concat", {{"X", names}}, {{"Out", {Output("Out")}}}, attrs);
    concat_op->Run(scope, place);
  }
};
// Declares the interface of the tensor_array_to_tensor operator: one input
// ("X"), two outputs ("Out" and "OutIndex"), the integer attribute "axis",
// and the operator's user-facing documentation string.
class LoDTensorArray2TensorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input LoDTensorArray of tensor_array_to_tensor operator.");
AddOutput("Out", "Output tensor of tensor_array_to_tensor operator.");
// OutIndex is filled by the forward operator's RunImpl with every array
// item's size along `axis`.
AddOutput("OutIndex",
"Output input LoDTensorArray items' dims of "
"tensor_array_to_tensor operator.");
// The operator-level default for `axis` is 0.
AddAttr<int>("axis",
"The axis along which the input tensors will be concatenated.")
.SetDefault(0);
// NOTE(review): the example below presumably treats every input item as a
// tensor of shape [1, 2], which is what makes OutputIndex [1,1,1] - confirm.
AddComment(R"DOC(
tensor_array_to_tensor Operator.
Concatenate the input LoDTensorArray along dimension axis to the output Tensor.
Examples:
Input = {[1,2], [3,4], [5,6]}
axis = 0
Output = [[1,2],
[3,4],
[5,6]]
OutputIndex = [1,1,1]
)DOC");
}
};
// Shape inference hook of tensor_array_to_tensor. It deliberately does
// nothing: the forward operator resizes "Out" and "OutIndex" itself inside
// RunImpl, from the dims of the array items it finds in the scope.
class LoDTensorArray2TensorOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext * /*ctx*/) const override {
    // Intentionally empty.
  }
};
// Shape inference hook of tensor_array_to_tensor_grad. It deliberately does
// nothing: the gradient operator sizes its per-item gradient tensors itself
// inside RunImpl, from the dims of the forward input array items.
class LoDTensorArray2TensorGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext * /*context*/) const override {
    // Intentionally empty.
  }
};
// Variable-type inference of tensor_array_to_tensor_grad: marks every
// variable bound to the X@GRAD output slot as a LOD_TENSOR_ARRAY.
class LoDTensorArray2TensorGradInferVarType
    : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
    const auto x_grad_slot = framework::GradVarName("X");
    for (auto &grad_var_name : op_desc.Output(x_grad_slot)) {
      auto *grad_var_desc = block->Var(grad_var_name);
      grad_var_desc->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
    }
  }
};
// Gradient operator of tensor_array_to_tensor.
//
// Splits the gradient of "Out" back into one gradient tensor per item of the
// forward input array "X" by delegating to the `concat_grad` operator, then
// publishes those tensors as the LoDTensorArray X@GRAD.
class LoDTensorArray2TensorGradOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto axis = Attr<int>("axis");

    // Attributes forwarded unchanged to the inner concat_grad operator.
    framework::AttributeMap attrs;
    attrs["axis"] = axis;

    auto &inx = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
    const size_t n = inx.size();
    PADDLE_ENFORCE_GT(n, 0, "Input tensorarray size should > 0.");

    // Expose every forward input item as its own scope variable.
    std::string base_name = Inputs("X")[0];
    std::vector<std::string> names;
    LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names);

    // grad
    auto dx_name = Output(framework::GradVarName("X"));
    auto dout_name = Input(framework::GradVarName("Out"));

    // The per-item gradient variables are named after this operator's own
    // X@GRAD output rather than a fixed prefix. With a fixed prefix every
    // instance of this operator in the scope would write into the same
    // variables, and because the result array shares their buffers, one
    // instance could overwrite gradients produced by another.
    std::vector<std::string> grad_names;
    LodTensorVectorResizeFromLodTensorArray(scope, dx_name, Input("X"),
                                            &grad_names);

    auto concat_grad_op = framework::OpRegistry::CreateOp(
        "concat_grad", {{"X", names}, {"Out@GRAD", {dout_name}}},
        {{"X@GRAD", grad_names}}, attrs);
    concat_grad_op->Run(scope, place);

    // Build the output array and point each entry at the matching per-item
    // gradient computed by concat_grad.
    LodTensorArrayCreateFromLodTensorArray(scope, Input("X"), dx_name);
    auto &grad_inx =
        *scope.FindVar(dx_name)->GetMutable<framework::LoDTensorArray>();
    for (size_t i = 0; i < grad_names.size(); i++) {
      std::string var_name = grad_names[i];
      auto &feed_input = scope.FindVar(var_name)->Get<framework::LoDTensor>();
      grad_inx[i].ShareDataWith(feed_input);
    }
  }
};
} // namespace operators
} // namespace paddle
// The operators in this file create `concat` / `concat_grad` operators
// through OpRegistry::CreateOp at run time.
// NOTE(review): USE_OP(concat) presumably forces the concat registration to be
// linked into this translation unit - confirm against op_registry.h.
USE_OP(concat);
namespace ops = paddle::operators;
// Forward operator: implementation, proto maker and (empty) shape inference.
// NOTE(review): DefaultGradOpDescMaker<true> presumably generates the
// description of tensor_array_to_tensor_grad automatically - confirm.
REGISTER_OPERATOR(tensor_array_to_tensor, ops::LoDTensorArray2TensorOp,
ops::LoDTensorArray2TensorOpMaker,
ops::LoDTensorArray2TensorOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
// Gradient operator: implementation, (empty) shape inference and the
// var-type inference that marks X@GRAD as a LOD_TENSOR_ARRAY.
REGISTER_OPERATOR(tensor_array_to_tensor_grad, ops::LoDTensorArray2TensorGradOp,
ops::LoDTensorArray2TensorGradInferShape,
ops::LoDTensorArray2TensorGradInferVarType);
......@@ -24,10 +24,10 @@ from .layer_function_generator import templatedoc
import numpy
__all__ = [
'create_tensor', 'create_parameter', 'create_global_var', 'cast', 'concat',
'sums', 'assign', 'fill_constant_batch_size_like', 'fill_constant',
'argmin', 'argmax', 'argsort', 'ones', 'zeros', 'reverse', 'has_inf',
'has_nan', 'isfinite'
'create_tensor', 'create_parameter', 'create_global_var', 'cast',
'tensor_array_to_tensor', 'concat', 'sums', 'assign',
'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax',
'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite'
]
......@@ -193,6 +193,60 @@ def concat(input, axis=0, name=None):
return out
def tensor_array_to_tensor(input, axis=1, name=None):
    """
    Concatenate all tensors of the input LoDTensorArray along ``axis`` and
    return the result together with the size of every array item along
    ``axis``.

    A simple example as below:

    .. code-block:: text

        Given:

        input.data = {[[0.6, 0.1, 0.3],
                       [0.5, 0.3, 0.2]],
                      [[1.3],
                       [1.8]],
                      [[2.3, 2.1],
                       [2.5, 2.4]]}

        axis = 1

        Then:

        output.data = [[0.6, 0.1, 0.3, 1.3, 2.3, 2.1],
                       [0.5, 0.3, 0.2, 1.8, 2.5, 2.4]]

        output_index.data = [3, 1, 2]

    Args:
        input(Variable): The input LoDTensorArray.
        axis(int): Integer axis along which the tensors will be concatenated.
        name(str|None): A name for this layer (optional). If set None, the
            layer will be named automatically.

    Returns:
        Variable: Output variable of the concatenation.
        Variable: int32 variable holding the size of each input
            LoDTensorArray item along ``axis``.

    Examples:
        .. code-block:: python

           output, output_index = fluid.layers.tensor_array_to_tensor(input=tensor_array)
    """
    # The op type must match the name the C++ operator is registered under
    # (`tensor_array_to_tensor`); the operator's former name
    # `tensor_array_concat` no longer exists.
    helper = LayerHelper('tensor_array_to_tensor', **locals())
    out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
    out_index = helper.create_variable_for_type_inference(dtype="int32")
    helper.append_op(
        type='tensor_array_to_tensor',
        inputs={'X': input},
        outputs={'Out': [out],
                 'OutIndex': [out_index]},
        attrs={'axis': axis})
    return out, out_index
def sums(input, out=None):
"""
This function performs the sum operation on the input and returns the
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor
class TestLoDTensorArrayConcat(unittest.TestCase):
    """Forward and backward check of the tensor_array_to_tensor op on CPU."""

    def setUp(self):
        self.op_type = "tensor_array_to_tensor"
        self.attrs = {"axis": 0}
        self.outputs = ["Out"]

    def test_get_set(self):
        """Concatenate ten tensors along axis 0 and verify Out, OutIndex and
        the gradient array produced by the generated grad op.

        Item 0 of the input array has shape [2, 1]; the other nine items have
        shape [1, 1], so the concatenated result has 11 rows.
        """
        scope = core.Scope()
        program = fluid.Program()
        block = program.global_block()

        # Input LoDTensorArray: declared in the program, stored in the scope.
        array_desc = block.create_var(
            name="tmp_lod_tensor_array",
            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY)
        array_desc.persistable = True
        tensor_array = scope.var(
            'tmp_lod_tensor_array').get_lod_tensor_array()
        self.assertEqual(0, len(tensor_array))

        cpu = core.CPUPlace()
        for i in range(10):
            rows = [[i], [i]] if i == 0 else [[i]]
            item = core.LoDTensor()
            item.set(numpy.array(rows, dtype='float32'), cpu)
            tensor_array.append(item)
        self.assertEqual(10, len(tensor_array))

        random_grad = numpy.random.random_sample([11]).astype(numpy.float32)

        # Forward outputs.
        out_desc = block.create_var(name="Out")
        out_desc.persistable = True
        out_index_desc = block.create_var(name="OutIndex")
        out_index_desc.persistable = True

        # Gradient of Out, fed through the scope.
        out_grad_desc = block.create_var(
            name='Out@GRAD', dtype='float32', shape=[11])
        out_grad_desc.persistable = True
        scope.var('Out@GRAD').get_tensor().set(random_grad, cpu)

        forward_op = block.append_op(
            type=self.op_type,
            inputs={"X": array_desc},
            outputs={"Out": out_desc,
                     "OutIndex": out_index_desc},
            attrs=self.attrs)

        # Gradient array of the input.
        array_grad_desc = block.create_var(
            name="tmp_lod_tensor_array@GRAD",
            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY)
        array_grad_desc.persistable = True

        # Append the automatically generated grad op to the block.
        grad_op_descs, _ = core.get_grad_op_desc(forward_op.desc, set(), [])
        grad_op_desc = grad_op_descs[0]
        appended_desc = block.desc.append_op()
        appended_desc.copy_from(grad_op_desc)
        for out_name in grad_op_desc.output_arg_names():
            block.desc.var(out_name.encode("ascii"))

        grad_op_desc.infer_var_type(block.desc)
        grad_op_desc.infer_shape(block.desc)
        for out_name in grad_op_desc.output_arg_names():
            block.desc.find_var(out_name.encode("ascii")).set_dtype(
                core.VarDesc.VarType.FP32)

        fetch_list = [block.var('Out'), block.var('OutIndex')]

        exe = fluid.Executor(fluid.CPUPlace())
        results = exe.run(program, fetch_list=fetch_list, scope=scope)

        # test forward
        tensor_res = numpy.array(results[0])
        tensor_res_out_idx = numpy.array(results[1])
        tensor_gt = numpy.array(
            [0] + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='float32')

        self.assertEqual(len(tensor_res), len(tensor_gt))
        self.assertEqual(len(tensor_res_out_idx), 10)

        for actual, expected in zip(tensor_res, tensor_gt):
            self.assertEqual(actual, expected)

        for pos, item_rows in enumerate(tensor_res_out_idx):
            # Only the first item contributes two rows.
            self.assertEqual(item_rows, 2 if pos == 0 else 1)

        # test backward
        grad_tensor_array = scope.var(
            'tmp_lod_tensor_array@GRAD').get_lod_tensor_array()
        self.assertEqual(10, len(grad_tensor_array))

        # Item 0 receives the first two gradient values, item 1 the third.
        first_item_grad = grad_tensor_array[0]
        self.assertEqual(
            numpy.array(first_item_grad)[0], numpy.array(random_grad[0]))
        self.assertEqual(
            numpy.array(first_item_grad)[1], numpy.array(random_grad[1]))
        self.assertEqual(
            numpy.array(grad_tensor_array[1]), numpy.array(random_grad[2]))


if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册