未验证 提交 c318aa5f 编写于 作者: G guochaorong 提交者: GitHub

Merge pull request #11850 from guochaorong/revert_11496

Revert "Extend fill_zeros_like_op for zero-filling an LoDTensorArray …
...@@ -748,10 +748,6 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -748,10 +748,6 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t = &var->Get<LoDTensor>(); t = &var->Get<LoDTensor>();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value()); t = &(var->Get<SelectedRows>().value());
} else if (var->IsType<LoDTensorArray>()) {
const LoDTensorArray& arr = var->Get<LoDTensorArray>();
PADDLE_ENFORCE(arr.size() > 0);
t = &(arr[0]);
} }
if (t != nullptr) { if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type())); int tmp = static_cast<int>(ToDataType(t->type()));
......
...@@ -26,12 +26,8 @@ class FillZerosLikeOp : public framework::OperatorWithKernel { ...@@ -26,12 +26,8 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
"Input(X) of FillZerosLikeOp should not be null."); "Input(X) of FillZerosLikeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FillZerosLikeOp should not be null."); "Output(Out) of FillZerosLikeOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
if (ctx->IsRuntime() && ctx->ShareLoD("X", /*->*/ "Out");
ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarType::LOD_TENSOR_ARRAY) {
return; // skip runtime infershape when is tensor array;
}
} }
}; };
...@@ -43,7 +39,7 @@ class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -43,7 +39,7 @@ class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
FillZerosLike Operator. FillZerosLike Operator.
Fill up a variable with zeros, supporting both LoDTensor and LoDTensorArray. Fill up a variable with zeros.
The output will have the same size as the input. The output will have the same size as the input.
)DOC"); )DOC");
......
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -24,30 +23,13 @@ template <typename DeviceContext, typename T> ...@@ -24,30 +23,13 @@ template <typename DeviceContext, typename T>
class FillZerosLikeKernel : public framework::OpKernel<T> { class FillZerosLikeKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto var = context.InputVar("X"); auto* out = context.Output<framework::Tensor>("Out");
if (var->IsType<framework::LoDTensor>()) { out->mutable_data<T>(context.GetPlace());
auto& input = *context.Input<framework::LoDTensor>("X");
auto& output = *context.Output<framework::LoDTensor>("Out");
output.Resize(input.dims());
output.set_lod(input.lod());
output.mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> setter;
setter(context.template device_context<DeviceContext>(), &(output),
static_cast<T>(0));
} else if (var->IsType<framework::LoDTensorArray>()) {
auto& input = *context.Input<framework::LoDTensorArray>("X");
auto& output = *context.Output<framework::LoDTensorArray>("Out");
output.resize(input.size());
for (auto i = 0; i < input.size(); i++) {
output[i].Resize(input[i].dims());
output[i].set_lod(input[i].lod());
output[i].mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> setter; math::SetConstant<DeviceContext, T> setter;
setter(context.template device_context<DeviceContext>(), &(output[i]), setter(context.template device_context<DeviceContext>(), out,
static_cast<T>(0)); static_cast<T>(0));
} }
}
}
}; };
} // namespace operators } // namespace operators
......
...@@ -95,7 +95,6 @@ __all__ = [ ...@@ -95,7 +95,6 @@ __all__ = [
'relu', 'relu',
'log', 'log',
'crop', 'crop',
'fill_zeros_like',
] ]
...@@ -5185,40 +5184,3 @@ def crop(x, shape=None, offsets=None, name=None): ...@@ -5185,40 +5184,3 @@ def crop(x, shape=None, offsets=None, name=None):
outputs={'Out': out}, outputs={'Out': out},
attrs=None if len(attrs) == 0 else attrs) attrs=None if len(attrs) == 0 else attrs)
return out return out
def fill_zeros_like(x):
"""
This layer takes an input and outputs a variable that has the same structure as
the input and with all the element values as zero. The variable can be a Tensor
or TensorArray.
.. code-block:: text
Given
X = [[0, 1, 2, 0],
[0, 3, 4, 0],
[0, 0, 0, 0]],
output is:
Out = [[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]].
Args:
x (Variable): The input variable, which could be a tensor or tensor array
Returns:
Variable: The zero-filled variable, which has the same type and shape as
the input variable.
Examples:
.. code-block:: python
y = fluid.layers.fill_zeros_like(x)
"""
helper = LayerHelper('fill_zeros_like', **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type='fill_zeros_like', inputs={'X': [x]}, outputs={'Out': [out]})
return out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid.core as core
import numpy
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.executor import Executor
import paddle.fluid as fluid
import paddle.fluid.core as core
class TestFillZerosLikeOpForTensorArray(unittest.TestCase):
def place(self):
return core.CPUPlace()
def test_zero_filling_lod_tensor_array(self):
tensor = core.LoDTensor()
tensor.set(
numpy.arange(20).reshape(20, 1).astype('int32'), self.place())
tensor.set_lod([[0, 2, 5], [0, 3, 9, 11, 17, 20]])
expect = [
numpy.array(
[0, 0, 0, 0, 0], dtype='int32'), numpy.array(
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype='int32'),
numpy.array(
[0, 0, 0], dtype='int32')
]
lod = [[[0, 2, 5]], [[0, 6, 12]], [[0, 3]]]
self.main(
tensor=tensor,
expect_array=expect,
expect_lod=lod,
expect_max_len=3)
def main(self, tensor, expect_array, expect_lod, expect_max_len, level=0):
place = self.place()
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[10])
x.persistable = True
table = layers.lod_rank_table(x, level=level)
max_len = layers.max_sequence_len(table)
max_len.persistable = True
array = layers.lod_tensor_to_array(x, table)
array = layers.fill_zeros_like(array)
array.persistable = True
result = layers.array_to_lod_tensor(array, table)
result.persistable = True
exe = Executor(place)
scope = core.Scope()
exe.run(program, feed={'x': tensor}, scope=scope)
var = scope.find_var(array.name)
array = var.get_lod_tensor_array()
if expect_array is not None and expect_lod is not None:
self.check_array_same(array, expect_array, expect_lod)
self.assertEqual(
numpy.array(scope.find_var(max_len.name).get_tensor())[0],
expect_max_len)
def check_array_same(self, array, expect_tensor, expect_lod):
self.assertEqual(len(expect_tensor), len(array))
for i, exp in enumerate(zip(expect_tensor, expect_lod)):
exp_tensor, exp_lod = exp
exp_tensor = numpy.expand_dims(exp_tensor, axis=1)
self.assertTrue(numpy.allclose(exp_tensor, numpy.array(array[i])))
self.assertEqual(exp_lod, array[i].lod())
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册