From 2ff18e537f0b31f34a6cd9c2c5518c639672f42b Mon Sep 17 00:00:00 2001
From: JesseyXujin
Date: Mon, 14 Oct 2019 11:12:19 +0800
Subject: [PATCH] add expand_as op, test=develop (#20565)

* add expand_as op, test=develop
* add expand_as op, test=develop
* add expand_as op, test=develop
* add nn.py, test=develop
* delete paddle_enforce, test=develop
---
 paddle/fluid/API.spec                         |   1 +
 paddle/fluid/operators/expand_as_op.cc        | 127 ++++++++++++
 paddle/fluid/operators/expand_as_op.cu        |  22 +++
 paddle/fluid/operators/expand_as_op.h         | 185 ++++++++++++++++++
 python/paddle/fluid/layers/nn.py              |  71 +++++++
 .../tests/unittests/test_expand_as_op.py      | 130 ++++++++++++
 6 files changed, 536 insertions(+)
 mode change 100644 => 100755 paddle/fluid/API.spec
 create mode 100644 paddle/fluid/operators/expand_as_op.cc
 create mode 100755 paddle/fluid/operators/expand_as_op.cu
 create mode 100755 paddle/fluid/operators/expand_as_op.h
 create mode 100755 python/paddle/fluid/tests/unittests/test_expand_as_op.py

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
old mode 100644
new mode 100755
index c5d9d6e4ba..d9e1b1200a
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -243,6 +243,7 @@ paddle.fluid.layers.sequence_enumerate (ArgSpec(args=['input', 'win_size', 'pad_
 paddle.fluid.layers.unique (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', 'cab0b06e5683875f12f0efc62fa230a9'))
 paddle.fluid.layers.unique_with_counts (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', '4496682f302007019e458a2f30d8a7c3'))
 paddle.fluid.layers.expand (ArgSpec(args=['x', 'expand_times', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'e93a1b102ab64b247c1b774e60d4c0d0'))
+paddle.fluid.layers.expand_as (ArgSpec(args=['x', 'target_tensor', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'ca6b29aa6987776628a0b33f6dcaaaa6'))
 paddle.fluid.layers.sequence_concat (ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'f47f9d207ac60b6f294087bcb1b64ae8'))
 paddle.fluid.layers.scale (ArgSpec(args=['x', 'scale', 'bias', 'bias_after_scale', 'act', 'name'], varargs=None, keywords=None, defaults=(1.0, 0.0, True, None, None)), ('document', 'a33547d41970fa3c59e6b2f21fe5f76d'))
 paddle.fluid.layers.elementwise_add (ArgSpec(args=['x', 'y', 'axis', 'act', 'name'], varargs=None, keywords=None, defaults=(-1, None, None)), ('document', '0c9c260e7738165a099f6a76da0b7814'))
diff --git a/paddle/fluid/operators/expand_as_op.cc b/paddle/fluid/operators/expand_as_op.cc
new file mode 100644
index 0000000000..204a93df23
--- /dev/null
+++ b/paddle/fluid/operators/expand_as_op.cc
@@ -0,0 +1,127 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/expand_as_op.h"
+#include <memory>
+#include <vector>
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class ExpandAsOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true);
+    PADDLE_ENFORCE_EQ(ctx->HasInput("target_tensor"), true);
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true);
+    auto x_dims = ctx->GetInputDim("X");
+    auto target_tensor_dims = ctx->GetInputDim("target_tensor");
+    PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()),
+                      target_tensor_dims.size(),
+                      "The rank of Input(target_tensor) must be equal "
+                      "to the rank of Input(X).");
+    PADDLE_ENFORCE_LE(x_dims.size(), 6,
+                      "The rank of Input(X) must not be greater than 6.");
+    // The actual output shape is set in the kernel, where the runtime shape
+    // of target_tensor is available; see ExpandAsKernel::ExpandAs.
+    std::vector<int64_t> out_shape(x_dims.size());
+    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
+  }
+};
+
+class ExpandAsOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]. "
+             "X is the input to be expanded.");
+    AddInput("target_tensor",
+             "(Tensor). The tensor whose shape gives, per dimension, the "
+             "size X is expanded to.");
+    AddOutput("Out",
+              "(Tensor, default Tensor<float>). A tensor with rank in [1, 6]. "
+              "Output(Out) has the same rank as Input(X). After expanding, "
+              "the size of each dimension of Output(Out) equals the size of "
+              "the corresponding dimension of Input(target_tensor).");
+    AddComment(R"DOC(
+ExpandAs operator tiles the input to match the shape of target_tensor. The
+repeat count along each dimension is the size of that dimension of
+target_tensor divided by the size of the corresponding dimension of X. The
+rank of X must be in [1, 6], and target_tensor must have the same rank as X.
+
+The following is an example:
+
+Input(X) is a 3-D tensor with shape [2, 3, 1]:
+
+        [
+           [[1], [2], [3]],
+           [[4], [5], [6]]
+        ]
+
+target_tensor's shape is [2, 6, 2].
+
+Output(Out) is a 3-D tensor with shape [2, 6, 2]:
+
+        [
+            [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
+            [[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
+        ]
+
+)DOC");
+  }
+};
+
+class ExpandAsGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true);
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true);
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+  }
+};
+
+class ExpandAsGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("expand_as_grad");
+    op->SetInput("X", Input("X"));
+    op->SetInput("target_tensor", Input("target_tensor"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+// DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ExpandGradNoNeedBufVarsInferer, "X");
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(expand_as, ops::ExpandAsOp, ops::ExpandAsOpMaker,
+                  ops::ExpandAsGradOpDescMaker);
+REGISTER_OPERATOR(expand_as_grad, ops::ExpandAsGradOp);
+REGISTER_OP_CPU_KERNEL(
+    expand_as, ops::ExpandAsKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ExpandAsKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::ExpandAsKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::ExpandAsKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    expand_as_grad,
+    ops::ExpandAsGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ExpandAsGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/expand_as_op.cu b/paddle/fluid/operators/expand_as_op.cu
new file mode 100755
index 0000000000..d7c894d248
--- /dev/null
+++ b/paddle/fluid/operators/expand_as_op.cu
@@ -0,0 +1,22 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include "paddle/fluid/operators/expand_as_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    expand_as, ops::ExpandAsKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ExpandAsKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ExpandAsKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ExpandAsKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    expand_as_grad,
+    ops::ExpandAsGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ExpandAsGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/expand_as_op.h b/paddle/fluid/operators/expand_as_op.h
new file mode 100755
index 0000000000..249f4c35a7
--- /dev/null
+++ b/paddle/fluid/operators/expand_as_op.h
@@ -0,0 +1,185 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+
+#include <boost/preprocessor/arithmetic/div.hpp>
+#include <boost/preprocessor/arithmetic/mod.hpp>
+#include <boost/preprocessor/comparison/greater.hpp>
+#include <boost/preprocessor/comparison/greater_equal.hpp>
+#include <boost/preprocessor/control/if.hpp>
+#include <boost/preprocessor/repetition/repeat.hpp>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+#define MAX_RANK_SUPPORTED 6
+
+// Generate one switch case per supported rank; each case dispatches to the
+// rank-templated ExpandAs<n + 1>.
+#define EXPAND_AS_TEMPLATE(z, n, data) \
+  case n + 1: {                        \
+    ExpandAs<n + 1>(context);          \
+    break;                             \
+  }
+#define REP_EXPAND_AS_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_AS_TEMPLATE, ~)
+#define COND(n)                                               \
+  BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \
+                         BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
+#define EXPAND_AS_GRAD_CASE(n)                                       \
+  case n: {                                                          \
+    ExpandAsBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
+    break;                                                           \
+  }
+#define EXPAND_AS_GRAD_TEMPLATE(z, n, data) \
+  BOOST_PP_IF(COND(n), EXPAND_AS_GRAD_CASE(n), )
+#define REP_EXPAND_AS_GRAD_TEMPLATE(n) \
+  BOOST_PP_REPEAT(n, EXPAND_AS_GRAD_TEMPLATE, ~)
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+
+template <typename DeviceContext, typename T>
+class ExpandAsKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto rank = context.Input<Tensor>("X")->dims().size();
+    switch (rank) {
+      REP_EXPAND_AS_TEMPLATE(MAX_RANK_SUPPORTED)
+      default:
+        PADDLE_THROW("Only support tensor with rank being between 1 and 6.");
+    }
+  }
+
+ protected:
+  template <int Rank>
+  void ExpandAs(const framework::ExecutionContext& context) const {
+    auto* in0 = context.Input<Tensor>("X");
+    auto in_dims = in0->dims();
+    auto* target_tensor = context.Input<Tensor>("target_tensor");
+    auto* out0 = context.Output<Tensor>("Out");
+    Eigen::DSizes<int, Rank> bcast_dims;
+    int bcast_dims_remainder = 0;
+    auto x_dims = in0->dims();
+    auto y_dims = target_tensor->dims();
+    for (int i = 0; i < y_dims.size(); ++i) {
+      PADDLE_ENFORCE_NE(x_dims[i], 0, "Input(X) should not have a 0 dim.");
+      bcast_dims[i] = y_dims[i] / x_dims[i];
+      bcast_dims_remainder += y_dims[i] % x_dims[i];
+    }
+    PADDLE_ENFORCE_EQ(bcast_dims_remainder, 0,
+                      "Input(X) could not be broadcast to the shape of "
+                      "Input(target_tensor): each dimension of target_tensor "
+                      "must be a multiple of the corresponding dimension of "
+                      "X.");
+    framework::DDim out_dims(in_dims);
+    for (size_t i = 0; i < bcast_dims.size(); ++i) {
+      out_dims[i] *= bcast_dims[i];
+    }
+
+    out0->Resize(out_dims);
+    auto x = EigenTensor<T, Rank>::From(*in0);
+    out0->mutable_data<T>(context.GetPlace());
+    auto y = EigenTensor<T, Rank>::From(*out0);
+    auto& place =
+        *context.template device_context<DeviceContext>().eigen_device();
+    y.device(place) = x.broadcast(bcast_dims);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ExpandAsGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in0 = context.Input<Tensor>("X");
+    auto* target_tensor = context.Input<Tensor>("target_tensor");
+    auto x_dims = in0->dims();
+    auto y_dims = target_tensor->dims();
+    std::vector<int> bcast_dims;
+    for (int i = 0; i < y_dims.size(); ++i) {
+      bcast_dims.push_back(y_dims[i] / x_dims[i]);
+    }
+    // Build the reshape/reduce plan: every expanded dimension contributes a
+    // (repeat, size) pair to the reshape and its repeat axis to the reduce.
+    std::vector<int> reshape_dims_vec;
+    std::vector<int> reduce_dims_vec;
+    for (size_t i = 0; i < bcast_dims.size(); ++i) {
+      if (bcast_dims[i] == 1) {
+        reshape_dims_vec.push_back(x_dims[i]);
+      } else {
+        if (x_dims[i] == 1) {
+          reduce_dims_vec.push_back(reshape_dims_vec.size());
+          reshape_dims_vec.push_back(bcast_dims[i]);
+        } else {
+          reduce_dims_vec.push_back(reshape_dims_vec.size());
+          reshape_dims_vec.push_back(bcast_dims[i]);
+          reshape_dims_vec.push_back(x_dims[i]);
+        }
+      }
+    }
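+    // Both vector lengths are packed into one switch key so the single
+    // BOOST_PP-generated switch below can pick the right ExpandAsBackward
+    // instantiation:
+    //   dims = (reshape_dims_vec.size() - 1) * MAX_RANK_SUPPORTED
+    //        + (reduce_dims_vec.size() - 1)
+    // ExpandAsBackward<Dims> recovers the two sizes as
+    // Dims / MAX_RANK_SUPPORTED + 1 and Dims % MAX_RANK_SUPPORTED + 1.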
+    int dims = reshape_dims_vec.size() * MAX_RANK_SUPPORTED +
+               reduce_dims_vec.size() - MAX_RANK_SUPPORTED - 1;
+    // No reduction is needed when nothing was expanded; just copy.
+    if (reduce_dims_vec.size() == 0) {
+      auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
+      out0->mutable_data<T>(context.GetPlace());
+      framework::TensorCopy(*in0, context.GetPlace(), context.device_context(),
+                            out0);
+    } else {
+      switch (dims) {
+        REP_EXPAND_AS_GRAD_TEMPLATE(72)
+        default:
+          PADDLE_THROW("Only support tensor with rank being between 1 and 6.");
+      }
+    }
+  }
+
+ protected:
+  template <int Dims>
+  void ExpandAsBackward(const framework::ExecutionContext& context,
+                        const std::vector<int>& reshape_dims_vec,
+                        const std::vector<int>& reduce_dims_vec) const {
+    size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1;
+    size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1;
+    PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
+                      "Inconsistent size between template Dims and "
+                      "reshape dimensions.");
+    PADDLE_ENFORCE_EQ(reduce_size, reduce_dims_vec.size(),
+                      "Inconsistent size between template Dims and "
+                      "reduce dimensions.");
+    auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
+    out0->mutable_data<T>(context.GetPlace());
+    auto x_grad = EigenVector<T>::Flatten(*out0);
+    Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
+    for (size_t i = 0; i < reshape_size; ++i) {
+      reshape_dims[i] = reshape_dims_vec[i];
+    }
+    Eigen::DSizes<int, Dims % MAX_RANK_SUPPORTED + 1> reduce_dims;
+    for (size_t i = 0; i < reduce_size; ++i) {
+      reduce_dims[i] = reduce_dims_vec[i];
+    }
+    // Fold every tiled copy back onto the input gradient: reshape the output
+    // gradient so each repeat gets its own axis, sum those axes away, then
+    // reshape to the input's shape.
+    auto out_grad = EigenVector<T>::Flatten(*in0);
+    x_grad.device(
+        *context.template device_context<DeviceContext>().eigen_device()) =
+        out_grad.reshape(reshape_dims)
+            .sum(reduce_dims)
+            .reshape(x_grad.dimensions());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 93bc589b51..aef079e367 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -157,6 +157,7 @@ __all__ = [
     'unique',
     'unique_with_counts',
     'expand',
+    'expand_as',
     'sequence_concat',
     'scale',
     'elementwise_add',
@@ -12864,6 +12865,76 @@ def expand(x, expand_times, name=None):
     return out
 
 
+def expand_as(x, target_tensor, name=None):
+    """
+    expand_as operator tiles the input x to match the shape of target_tensor.
+    The repeat count along each dimension is target_tensor's size in that
+    dimension divided by x's size in it. The rank of x should be in [1, 6],
+    and target_tensor must have the same rank as x.
+
+    .. code-block:: text
+
+        Input(X) is a 3-D tensor with shape [2, 3, 1]:
+
+                [
+                   [[1], [2], [3]],
+                   [[4], [5], [6]]
+                ]
+
+        target_tensor's shape is [2, 6, 2].
+
+        Output(Out) is a 3-D tensor with shape [2, 6, 2]:
+
+                [
+                    [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
+                    [[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
+                ]
+
+    Args:
+        x (Variable): A Tensor with dtype float64, float32 or int32,
+            with rank in [1, 6].
+        target_tensor (Variable): A Tensor with dtype float64, float32 or
+            int32. Only the shape of target_tensor is used; its values are
+            ignored.
+
+    Returns:
+        Variable: A Tensor with dtype float64, float32 or int32.
+            After expanding, the size of each dimension of Output(Out) equals
+            the size of the corresponding dimension of target_tensor.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            import numpy as np
+
+            data = fluid.layers.data(name="data", shape=[-1, 10], dtype='float64')
+            target_tensor = fluid.layers.data(
+                name="target_tensor", shape=[-1, 20], dtype='float64')
+            result = fluid.layers.expand_as(x=data, target_tensor=target_tensor)
+            use_cuda = False
+            place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            exe.run(fluid.default_startup_program())
+            x = np.random.rand(3, 10)
+            y = np.random.rand(3, 20)
+            output = exe.run(feed={"data": x, "target_tensor": y},
+                             fetch_list=[result.name])
+            print(output[0].shape)
+            # (3, 20)
+
+    """
+
+    helper = LayerHelper('expand_as', input=x, **locals())
+    dtype = helper.input_dtype(input_param_name='x')
+    out = helper.create_variable_for_type_inference(dtype)
+    inputs = {'X': x, 'target_tensor': target_tensor}
+    helper.append_op(type='expand_as', inputs=inputs, outputs={'Out': out})
+    return out
+
+
 from paddle.fluid.framework import convert_np_dtype_to_dtype_
diff --git a/python/paddle/fluid/tests/unittests/test_expand_as_op.py b/python/paddle/fluid/tests/unittests/test_expand_as_op.py
new file mode 100755
index 0000000000..836e49c54e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_expand_as_op.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
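+#
+# The tests below build the expected output with numpy.tile: along each
+# dimension i the input is repeated target.shape[i] // x.shape[i] times,
+# which is the same repeat count the expand_as kernel derives from
+# target_tensor.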
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.fluid as fluid
+
+
+def bcast(x, target_tensor):
+    x_dims = x.shape
+    y_dims = target_tensor.shape
+    bcast_dims = []
+    for i in range(len(x_dims)):
+        bcast_dims.append(int(y_dims[i] / x_dims[i]))
+    bcast_dims = np.array(bcast_dims).astype("int64")
+    return bcast_dims
+
+
+class TestExpandAsOpRank1(OpTest):
+    def setUp(self):
+        self.op_type = "expand_as"
+        x = np.random.rand(12).astype("float64")
+        target_tensor = np.random.rand(24).astype("float64")
+        self.inputs = {'X': x, 'target_tensor': target_tensor}
+        self.attrs = {}
+        bcast_dims = bcast(x, target_tensor)
+        output = np.tile(self.inputs['X'], bcast_dims)
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestExpandAsOpRank2(OpTest):
+    def setUp(self):
+        self.op_type = "expand_as"
+        x = np.random.rand(2, 3).astype("float64")
+        target_tensor = np.random.rand(4, 6).astype("float64")
+        self.inputs = {'X': x, 'target_tensor': target_tensor}
+        self.attrs = {}
+        bcast_dims = bcast(x, target_tensor)
+        output = np.tile(self.inputs['X'], bcast_dims)
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestExpandAsOpRank3(OpTest):
+    def setUp(self):
+        self.op_type = "expand_as"
+        x = np.random.rand(2, 3, 3).astype("float64")
+        target_tensor = np.random.rand(4, 6, 6).astype("float64")
+        self.inputs = {'X': x, 'target_tensor': target_tensor}
+        self.attrs = {}
+        bcast_dims = bcast(x, target_tensor)
+        output = np.tile(self.inputs['X'], bcast_dims)
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestExpandAsOpRank4(OpTest):
+    def setUp(self):
+        self.op_type = "expand_as"
+        x = np.random.rand(1, 1, 3, 16).astype("float64")
+        target_tensor = np.random.rand(4, 6, 6, 32).astype("float64")
+        self.inputs = {'X': x, 'target_tensor': target_tensor}
+        self.attrs = {}
+        bcast_dims = bcast(x, target_tensor)
+        output = np.tile(self.inputs['X'], bcast_dims)
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+# Test python API
+class TestExpandAPI(OpTest):
+    def test_api(self):
+        input1 = np.random.random([12, 14]).astype("float32")
+        input2 = np.random.random([48, 14]).astype("float32")
+        x = fluid.layers.data(
+            name='x', shape=[12, 14], append_batch_size=False, dtype="float32")
+
+        y = fluid.layers.data(
+            name='target_tensor',
+            shape=[48, 14],
+            append_batch_size=False,
+            dtype="float32")
+
+        out_1 = fluid.layers.expand_as(x, target_tensor=y)
+
+        exe = fluid.Executor(place=fluid.CPUPlace())
+        res_1 = exe.run(fluid.default_main_program(),
+                        feed={"x": input1,
+                              "target_tensor": input2},
+                        fetch_list=[out_1])
+        assert np.array_equal(res_1[0], np.tile(input1, (4, 1)))
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
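
As a sanity check on the gradient path above, the reshape-then-sum reduction
performed by ExpandAsGradKernel can be reproduced in NumPy. The sketch below
is standalone and not part of the patch; its shapes mirror
TestExpandAsOpRank2, and the reshape/reduce axes are the ones the kernel's
plan-building loop produces for that case:

    import numpy as np

    # Forward: expand_as tiles x by target.shape[i] // x.shape[i] per axis.
    x = np.random.rand(2, 3)
    target_shape = (4, 6)
    bcast = [t // s for t, s in zip(target_shape, x.shape)]  # [2, 2]
    out = np.tile(x, bcast)  # shape (4, 6)

    # Backward: give every repeated copy its own axis, then sum those axes
    # away. For this case the kernel builds reshape_dims_vec = [2, 2, 2, 3]
    # and reduce_dims_vec = [0, 2], i.e. this reshape plus sum:
    grad_out = np.ones(target_shape)
    grad_x = grad_out.reshape(2, 2, 2, 3).sum(axis=(0, 2))
    assert grad_x.shape == x.shape
    assert np.allclose(grad_x, 4.0)  # each x element feeds 2 * 2 outputs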