diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 391799e700034413e28205adc0df8fce46c4f76a..236fa84f1bbb800b69e8cb4c0b503eeecc943f3b 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -107,6 +107,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer)
 set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS})
 set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies")
 
+cc_test(assign_op_test SRCS assign_op_test.cc DEPS assign_op)
 cc_test(gather_test SRCS gather_test.cc DEPS tensor)
 cc_test(scatter_test SRCS scatter_test.cc DEPS tensor math_function)
 cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc
index 315982ec8fb49765db8a7a21a0810cbb6b58339f..0a89b2e416b09475b3f19785306199e554b9bf1d 100644
--- a/paddle/fluid/operators/assign_op.cc
+++ b/paddle/fluid/operators/assign_op.cc
@@ -12,59 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/var_type.h"
-#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/operators/assign_op.h"
+
+#include <memory>
+#include <string>
 
 namespace paddle {
 namespace operators {
-class AssignFunctor {
- public:
-  AssignFunctor(framework::Variable *out,
-                const platform::DeviceContext &dev_ctx)
-      : out_(out), dev_ctx_(dev_ctx) {}
-
-  void operator()(const framework::LoDTensor &lod_tensor) const {
-    auto &out_tensor = *out_->GetMutable<framework::LoDTensor>();
-    copy_tensor(lod_tensor, &out_tensor);
-  }
-
-  void operator()(const framework::LoDTensorArray &array) const {
-    auto &out_array = *out_->GetMutable<framework::LoDTensorArray>();
-    out_array.resize(array.size());
-    for (size_t i = 0; i < array.size(); ++i) {
-      copy_tensor(array[i], &out_array[i]);
-    }
-  }
-
-  void operator()(const framework::SelectedRows &rows) const {
-    framework::SelectedRows &out_rows =
-        *out_->GetMutable<framework::SelectedRows>();
-    out_rows.set_rows(rows.rows());
-    out_rows.set_height(rows.height());
-    auto &t = rows.value();
-    auto *m = out_rows.mutable_value();
-    framework::TensorCopy(t, t.place(), dev_ctx_, m);
-  }
-
-  template <typename T>
-  void operator()(const T &v) const {
-    PADDLE_THROW("Not support type for assign op %s", typeid(T).name());
-  }
-
- private:
-  void copy_tensor(const framework::LoDTensor &lod_tensor,
-                   framework::LoDTensor *out) const {
-    if (lod_tensor.numel() == 0) return;
-    auto &out_tensor = *out;
-    TensorCopy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor);
-    out_tensor.set_lod(lod_tensor.lod());
-  }
-
-  framework::Variable *out_;
-  const platform::DeviceContext &dev_ctx_;
-};
 
 class AssignOp : public framework::OperatorWithKernel {
  public:
diff --git a/paddle/fluid/operators/assign_op.h b/paddle/fluid/operators/assign_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..6718999d7f70d1182de958c1b7f574284c7b449f
--- /dev/null
+++ b/paddle/fluid/operators/assign_op.h
@@ -0,0 +1,72 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/var_type.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+class AssignFunctor {
+ public:
+  AssignFunctor(framework::Variable *out,
+                const platform::DeviceContext &dev_ctx)
+      : out_(out), dev_ctx_(dev_ctx) {}
+
+  void operator()(const framework::LoDTensor &lod_tensor) const {
+    auto &out_tensor = *out_->GetMutable<framework::LoDTensor>();
+    copy_tensor(lod_tensor, &out_tensor);
+  }
+
+  void operator()(const framework::LoDTensorArray &array) const {
+    auto &out_array = *out_->GetMutable<framework::LoDTensorArray>();
+    out_array.resize(array.size());
+    for (size_t i = 0; i < array.size(); ++i) {
+      copy_tensor(array[i], &out_array[i]);
+    }
+  }
+
+  void operator()(const framework::SelectedRows &rows) const {
+    framework::SelectedRows &out_rows =
+        *out_->GetMutable<framework::SelectedRows>();
+    out_rows.set_rows(rows.rows());
+    out_rows.set_height(rows.height());
+    auto &t = rows.value();
+    auto *m = out_rows.mutable_value();
+    framework::TensorCopy(t, t.place(), dev_ctx_, m);
+  }
+
+  template <typename T>
+  void operator()(const T &v) const {
+    PADDLE_THROW("Not support type for assign op %s", typeid(T).name());
+  }
+
+ private:
+  void copy_tensor(const framework::LoDTensor &lod_tensor,
+                   framework::LoDTensor *out) const {
+    if (lod_tensor.numel() == 0) return;
+    auto &out_tensor = *out;
+    TensorCopy(lod_tensor, lod_tensor.place(), dev_ctx_, &out_tensor);
+    out_tensor.set_lod(lod_tensor.lod());
+  }
+
+  framework::Variable *out_;
+  const platform::DeviceContext &dev_ctx_;
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/assign_op_test.cc b/paddle/fluid/operators/assign_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..58f360ad6059e4fe9743439cd0a16b1d7b9e241f
--- /dev/null
+++ b/paddle/fluid/operators/assign_op_test.cc
@@ -0,0 +1,118 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include "paddle/fluid/operators/assign_op.h"
+
+#include <gtest/gtest.h>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/ddim.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/variable.h"
+#include "paddle/fluid/platform/place.h"
+
+TEST(AssignOp, AssignLoDTensor) {
+  paddle::platform::CPUPlace cpu_place;
+  paddle::platform::CPUDeviceContext ctx(cpu_place);
+
+  paddle::framework::Variable output;
+  paddle::operators::AssignFunctor assign_functor(&output, ctx);
+
+  paddle::framework::LoDTensor input;
+  paddle::framework::DDim in_dims = paddle::framework::make_ddim({3, 4});
+  int* in_data = input.mutable_data<int>(in_dims, cpu_place);
+  for (int i = 0; i < 12; ++i) {
+    in_data[i] = i;
+  }
+
+  assign_functor(input);
+
+  auto& out_tensor = output.Get<paddle::framework::LoDTensor>();
+  paddle::framework::DDim out_dims = out_tensor.dims();
+  EXPECT_EQ(in_dims, out_dims);
+  auto* out_data = out_tensor.data<int>();
+  for (int i = 0; i < 12; ++i) {
+    EXPECT_EQ(i, out_data[i]);
+  }
+}
+
+TEST(AssignOp, AssignLoDTensorArray) {
+  paddle::platform::CPUPlace cpu_place;
+  paddle::platform::CPUDeviceContext ctx(cpu_place);
+
+  paddle::framework::Variable output;
+  paddle::operators::AssignFunctor assign_functor(&output, ctx);
+
+  paddle::framework::LoDTensorArray input;
+  for (int i = 0; i < 5; ++i) {
+    paddle::framework::DDim in_dims =
+        paddle::framework::make_ddim({i + 1, i + 2});
+    paddle::framework::LoDTensor lod_tensor;
+    float* in_data = lod_tensor.mutable_data<float>(in_dims, cpu_place);
+    for (int j = 0; j < (i + 1) * (i + 2); ++j) {
+      in_data[j] = static_cast<float>(j);
+    }
+    input.push_back(lod_tensor);
+  }
+
+  assign_functor(input);
+
+  auto& out_array = output.Get<paddle::framework::LoDTensorArray>();
+  for (int i = 0; i < 5; ++i) {
+    paddle::framework::DDim out_dims = out_array[i].dims();
+    EXPECT_EQ(paddle::framework::make_ddim({i + 1, i + 2}), out_dims);
+    const float* out_data = out_array[i].data<float>();
+    for (int j = 0; j < (i + 1) * (i + 2); ++j) {
+      EXPECT_EQ(static_cast<float>(j), out_data[j]);
+    }
+  }
+}
+
+TEST(AssignOp, AssignSelectedRows) {
+  paddle::platform::CPUPlace cpu_place;
+  paddle::platform::CPUDeviceContext ctx(cpu_place);
+
+  paddle::framework::Variable output;
+  paddle::operators::AssignFunctor assign_functor(&output, ctx);
+
+  std::vector<int64_t> rows{0, 4, 7};
+  int64_t height = 10;
+
+  paddle::framework::SelectedRows input(rows, height);
+  paddle::framework::Tensor* input_tensor = input.mutable_value();
+
+  paddle::framework::DDim in_dims = paddle::framework::make_ddim({3, 4});
+  int* in_data = input_tensor->mutable_data<int>(in_dims, cpu_place);
+  for (int i = 0; i < 12; ++i) {
+    in_data[i] = i;
+  }
+
+  assign_functor(input);
+
+  auto& out_selected_row = output.Get<paddle::framework::SelectedRows>();
+  const paddle::framework::Vector<int64_t>& out_rows = out_selected_row.rows();
+  EXPECT_EQ(rows.size(), out_rows.size());
+  for (size_t i = 0; i < rows.size(); ++i) {
+    EXPECT_EQ(rows[i], out_rows[i]);
+  }
+  EXPECT_EQ(height, out_selected_row.height());
+  const paddle::framework::Tensor& out_tensor = out_selected_row.value();
+  paddle::framework::DDim out_dims = out_tensor.dims();
+  EXPECT_EQ(in_dims, out_dims);
+  auto* out_data = out_tensor.data<int>();
+  for (int i = 0; i < 12; ++i) {
+    EXPECT_EQ(i, out_data[i]);
+  }
+}
diff --git a/paddle/fluid/operators/select_input_op.cc b/paddle/fluid/operators/select_input_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..33a5ff99a5d984d2327aac6b05421891f6c05e14
--- /dev/null
+++ b/paddle/fluid/operators/select_input_op.cc
@@ -0,0 +1,114 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/fluid/operators/assign_op.h"
+#include "paddle/fluid/operators/select_op_helper.h"
+
+namespace paddle {
+namespace operators {
+
+// SelectInputOp takes multiple inputs and uses an integer mask to select
+// one input to output. It is used in control flow.
+class SelectInputOp : public framework::OperatorBase {
+ public:
+  SelectInputOp(const std::string &type,
+                const framework::VariableNameMap &inputs,
+                const framework::VariableNameMap &outputs,
+                const framework::AttributeMap &attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    platform::DeviceContextPool &pool =
+        platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(dev_place);
+
+    auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
+    size_t output_branch = static_cast<size_t>(GetBranchNumber(mask));
+
+    const std::vector<std::string> &x_names = Inputs("X");
+    PADDLE_ENFORCE_LT(output_branch, x_names.size(),
+                      "Selected branch number is greater than actual branch "
+                      "num in SelectInputOp");
+
+    const framework::Variable *selected_x =
+        scope.FindVar(x_names[output_branch]);
+    framework::Variable *out = scope.FindVar(Output("Out"));
+    framework::VisitVarType(*selected_x, AssignFunctor(out, dev_ctx));
+  }
+};
+
+class SelectInputOpProtoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "The input LoDTensors or LoDTensorArray or SelectedRows. All "
+             "inputs must have the same variable type.")
+        .AsDuplicable();
+    AddInput("Mask",
+             "An integer tensor with numel 1 specifying which input to "
+             "output.");
+    AddOutput(
+        "Out",
+        "The merged output. The variable type of the output must be the "
+        "same as X.");
+    // TODO(huihuangzheng): decide whether to add support for lod level.
+    // Because this op blocks the whole control flow, I am implementing an
+    // MVP (minimal viable product) here.
+    AddComment(R"DOC(
+Merge branches of LoDTensor into a single output with a mask integer
+specifying the output branch.
+)DOC"); + } +}; + +class SelectInputInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE_EQ(context->HasInputs("X"), true, + "SelectInputOp must have input X."); + PADDLE_ENFORCE_EQ(context->HasInput("Mask"), true, + "SelectInputOp must have input Mask."); + PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true, + "SelectInputOp must have output Out."); + } +}; + +template +class SelectInputGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + std::unique_ptr Apply() const override { + auto *grad_op = new T(); + grad_op->SetType("select_output"); + grad_op->SetInput("X", this->OutputGrad("Out")); + grad_op->SetInput("Mask", this->Input("Mask")); + grad_op->SetOutput("Out", + this->InputGrad("X", /* drop_empty_grad */ false)); + grad_op->SetAttrMap(this->Attrs()); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(select_input, ops::SelectInputOp, + ops::SelectInputOpProtoMaker, ops::SelectInputInferShape, + ops::SelectInputGradMaker, + ops::SelectInputGradMaker); diff --git a/paddle/fluid/operators/select_op_helper.h b/paddle/fluid/operators/select_op_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..a159530d2a345d0f68342dc57c882fcbb843b318 --- /dev/null +++ b/paddle/fluid/operators/select_op_helper.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device_context.h" + +// Functions used in SelectInputOp and SelectOutputOp +namespace paddle { +namespace operators { + +// Returns the integer in mask whose numel must be 1. The integer means the +// selected branch number. +inline int GetBranchNumber(const framework::LoDTensor &mask) { + PADDLE_ENFORCE_EQ(mask.numel(), 1, + "Mask in SelectOutputOp must have numel 1."); + if (platform::is_cpu_place(mask.place())) { + return mask.data()[0]; + } + // when platform::is_gpu_place(mask.place()) is ture + std::unique_ptr cpu_mask{new framework::LoDTensor()}; +#ifdef PADDLE_WITH_CUDA + framework::TensorCopySync(mask, platform::CPUPlace(), cpu_mask.get()); +#else + PADDLE_THROW( + "This version of PaddlePaddle doen NOT support GPU but got GPU tensor " + "Mask in SelectOutputOp. Please compile WITH_GPU option"); +#endif + return cpu_mask->data()[0]; +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/select_output_op.cc b/paddle/fluid/operators/select_output_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac4bac4366a64f02b0d042db06f60dce4c94515d --- /dev/null +++ b/paddle/fluid/operators/select_output_op.cc @@ -0,0 +1,110 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/fluid/operators/assign_op.h"
+#include "paddle/fluid/operators/select_op_helper.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+
+// SelectOutputOp has one input, one integer mask and multiple outputs. It
+// selects the output specified by the mask and copies the input to it.
+class SelectOutputOp : public framework::OperatorBase {
+ public:
+  SelectOutputOp(const std::string &type,
+                 const framework::VariableNameMap &inputs,
+                 const framework::VariableNameMap &outputs,
+                 const framework::AttributeMap &attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    platform::DeviceContextPool &pool =
+        platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(dev_place);
+
+    auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
+    size_t output_branch = static_cast<size_t>(GetBranchNumber(mask));
+
+    const std::vector<std::string> &out_names = Outputs("Out");
+    PADDLE_ENFORCE_LT(output_branch, out_names.size(),
+                      "Selected branch number is greater than actual branch "
+                      "num in SelectOutputOp");
+
+    const framework::Variable *x = scope.FindVar(Input("X"));
+    framework::Variable *selected_out =
+        scope.FindVar(out_names[output_branch]);
+    framework::VisitVarType(*x, AssignFunctor(selected_out, dev_ctx));
+  }
+};
+
+class SelectOutputOpProtoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input LoDTensor or LoDTensorArray or SelectedRows.");
+    AddInput("Mask",
+             "A tensor with numel 1 specifying which branch to output.");
+    AddOutput("Out",
+              "The output can contain multiple variables. The output of the "
+              "selected branch will be the same as the input. We do nothing "
+              "for variables in other branches.")
+        .AsDuplicable();
+    // TODO(huihuangzheng): decide whether to add support for lod level.
+    // Because this op blocks the whole control flow, I am implementing an
+    // MVP (minimal viable product) here.
+    AddComment(R"DOC(
+Split the input variable into one output branch. The mask is an integer
+tensor specifying which output branch should copy the input.
+)DOC"); + } +}; + +class SelectOutputInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE_EQ(context->HasInput("X"), true, + "SelectOutputOp must have input X."); + PADDLE_ENFORCE_EQ(context->HasInput("Mask"), true, + "SelectOutputOp must have input Mask."); + PADDLE_ENFORCE_EQ(context->HasOutputs("Out"), true, + "SelectOutputOp must have output Out."); + } +}; + +template +class SelectOutputGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + std::unique_ptr Apply() const override { + auto *grad_op = new T(); + grad_op->SetType("select_input"); + grad_op->SetInput("Mask", this->Input("Mask")); + grad_op->SetInput("X", this->OutputGrad("Out")); + grad_op->SetOutput("Out", this->InputGrad("X")); + grad_op->SetAttrMap(this->Attrs()); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(select_output, ops::SelectOutputOp, + ops::SelectOutputOpProtoMaker, ops::SelectOutputInferShape, + ops::SelectOutputGradMaker, + ops::SelectOutputGradMaker); diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 01554a43e941bb5017d7f49da94fd9fc55b03652..a37e83d2a34ab8b001b50a6e99724dcd723f73f1 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -35,6 +35,60 @@ __all__ = [ ] +def select_output(input, outputs, mask): + """ + **select_output** + This API takes in one input and multiple outputs and an integer mask. It + selects the output specified by the mask and copy the input to selected + output. It is useful in control flow. + + Args: + input(Variable): The input variable + outputs(tuple|list): The output variables + mask(Variable): A tensor containing 1 integer number selecting which + output to be copied with input + + Returns: + Variable: The outputs variables + """ + helper = LayerHelper('select_output', **locals()) + helper.append_op( + type='select_output', + inputs={'X': input, + 'Mask': mask}, + outputs={'Out': outputs}) + return outputs + + +def select_input(inputs, mask): + """ + **select_input** + + This API takes in multiple inputs and uses an integer mask to select one + input to output. It is useful in control flow. + + Args: + inputs(tuple|list): The input variables + mask(Variable): A tensor containing 1 integer number selecting which + input to output + + Returns: + Variable: The selected input variable + """ + helper = LayerHelper('select_input', **locals()) + if isinstance(inputs, list) or isinstance(inputs, tuple): + input_dtype = inputs[0].dtype + else: + input_dtype = inputs.dtype + out = helper.create_variable(dtype=input_dtype) + helper.append_op( + type='select_input', + inputs={'X': inputs, + 'Mask': mask}, + outputs={'Out': out}) + return out + + def split_lod_tensor(input, mask, level=0): """ This function takes in an input that contains the complete lod information, diff --git a/python/paddle/fluid/tests/unittests/test_select_input_output_op.py b/python/paddle/fluid/tests/unittests/test_select_input_output_op.py new file mode 100644 index 0000000000000000000000000000000000000000..092262da4711a1a8fe3e3604ab79e282e6d68a67 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_select_input_output_op.py @@ -0,0 +1,65 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import paddle.fluid.layers as layers
+from paddle.fluid.backward import append_backward
+from paddle.fluid.executor import Executor
+from paddle.fluid.framework import Program, program_guard
+from paddle.fluid.layers.control_flow import select_input, select_output
+
+
+class TestSplitMergeSelectedVarOps(unittest.TestCase):
+    def test_forward_backward(self):
+        branch_num = 9
+        program = Program()
+        with program_guard(program):
+            x = layers.data(name='x', shape=[2], dtype='float32')
+            x.stop_gradient = False  # For test gradient
+            mask = layers.data(name='mask', shape=[1], dtype='int32')
+
+            outputs = []
+            for i in range(branch_num):
+                out = program.current_block().create_var(
+                    dtype='float32', type=core.VarDesc.VarType.LOD_TENSOR)
+                outputs.append(out)
+
+            select_output(x, outputs, mask)
+            y = select_input(outputs, mask)
+            mean = layers.mean(y)
+            append_backward(mean)
+
+        place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        exe = Executor(place)
+
+        feed_x = np.asarray([1.3, -1.4]).astype(np.float32)
+        for i in range(branch_num):
+            feed_mask = np.asarray([i]).astype(np.int32)
+            ret = exe.run(program,
+                          feed={'x': feed_x,
+                                'mask': feed_mask},
+                          fetch_list=[y.name, x.grad_name])
+            x_grad = np.asarray([0.5, 0.5]).astype(np.float32)
+            self.assertTrue(np.allclose(np.asarray(ret[0]), feed_x))
+            self.assertTrue(np.allclose(np.asarray(ret[1]), x_grad))
+
+
+if __name__ == '__main__':
+    unittest.main()
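Usage sketch: the unit test above drives select_output and select_input through nine branches together with the backward pass. For orientation, the following minimal two-branch wiring shows how the new Python layers are intended to be used; it is not part of the patch, it assumes only the fluid 1.x APIs already used in this diff (fluid.layers.data, Program/program_guard, Executor), and the names branches and main_program are illustrative.

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.layers.control_flow import select_input, select_output

main_program = fluid.Program()
with fluid.program_guard(main_program):
    x = fluid.layers.data(name='x', shape=[2], dtype='float32')
    mask = fluid.layers.data(name='mask', shape=[1], dtype='int32')

    # Two candidate branch variables; select_output copies x into the one
    # chosen by mask, and select_input reads the chosen branch back out.
    branches = [
        main_program.current_block().create_var(
            dtype='float32', type=core.VarDesc.VarType.LOD_TENSOR)
        for _ in range(2)
    ]
    select_output(x, branches, mask)
    y = select_input(branches, mask)

exe = fluid.Executor(fluid.CPUPlace())
feed_x = np.asarray([1.3, -1.4]).astype(np.float32)
out = exe.run(main_program,
              feed={'x': feed_x,
                    'mask': np.asarray([1]).astype(np.int32)},
              fetch_list=[y.name])[0]
# out equals feed_x, routed through branch 1 of the two branch variables.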