From 7c1755d93f7f046432b596aac6c271edc676b8ae Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 13 Nov 2017 18:31:22 -0800 Subject: [PATCH] Assign Operator. (#5531) * Assign Operator. Out=X, when type in [LoDTensor/SelectedRows/LoDTensorArray] * Follow comments --- paddle/framework/var_type.h | 22 +++ paddle/operators/assign_op.cc | 138 ++++++++++++++++++ .../v2/framework/tests/test_assign_op.py | 21 +++ 3 files changed, 181 insertions(+) create mode 100644 paddle/operators/assign_op.cc create mode 100644 python/paddle/v2/framework/tests/test_assign_op.py diff --git a/paddle/framework/var_type.h b/paddle/framework/var_type.h index d060196bb2c..0f19870bec3 100644 --- a/paddle/framework/var_type.h +++ b/paddle/framework/var_type.h @@ -27,10 +27,32 @@ inline VarDesc::VarType ToVarType(std::type_index type) { return VarDesc_VarType_LOD_RANK_TABLE; } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) { return VarDesc_VarType_LOD_TENSOR_ARRAY; + } else if (type.hash_code() == typeid(SelectedRows).hash_code()) { + return VarDesc_VarType_SELECTED_ROWS; } else { PADDLE_THROW("ToVarType:Unsupported type %s", type.name()); } } +template +inline void VisitVarType(const Variable& var, Visitor visitor) { + switch (ToVarType(var.Type())) { + case VarDesc_VarType_LOD_TENSOR: + visitor(var.Get()); + return; + case VarDesc_VarType_LOD_RANK_TABLE: + visitor(var.Get()); + return; + case VarDesc_VarType_LOD_TENSOR_ARRAY: + visitor(var.Get()); + return; + case VarDesc_VarType_SELECTED_ROWS: + visitor(var.Get()); + return; + default: + PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type())); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/operators/assign_op.cc b/paddle/operators/assign_op.cc new file mode 100644 index 00000000000..609e915b932 --- /dev/null +++ b/paddle/operators/assign_op.cc @@ -0,0 +1,138 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/data_type.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/var_type.h" + +namespace paddle { +namespace operators { +class AssignFunctor { + public: + AssignFunctor(framework::Variable *out, + const platform::DeviceContext &dev_ctx) + : out_(out), dev_ctx_(dev_ctx) {} + + void operator()(const framework::LoDTensor &lod_tensor) const { + auto &out_tensor = *out_->GetMutable(); + copy_tensor(lod_tensor, &out_tensor); + } + + void operator()(const framework::LoDTensorArray &array) const { + auto &out_array = *out_->GetMutable(); + out_array.resize(array.size()); + for (size_t i = 0; i < array.size(); ++i) { + copy_tensor(array[i], &out_array[i]); + } + } + + void operator()(const framework::SelectedRows &rows) const { + framework::SelectedRows &out_rows = + *out_->GetMutable(); + out_rows.set_rows(rows.rows()); + out_rows.set_height(rows.height()); + auto &t = rows.value(); + out_rows.mutable_value()->CopyFrom(t, t.place(), dev_ctx_); + } + + template + void operator()(const T &v) const { + PADDLE_THROW("Not support type for assign op %s", typeid(T).name()); + } + + private: + void copy_tensor(const framework::LoDTensor &lod_tensor, + framework::LoDTensor *out) const { + auto &out_tensor = *out; + out_tensor.CopyFrom(lod_tensor, lod_tensor.place(), dev_ctx_); + out_tensor.set_lod(lod_tensor.lod()); + } + + framework::Variable *out_; + const platform::DeviceContext &dev_ctx_; +}; + +class AssignOp : public framework::OperatorBase { + public: + AssignOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto *x = scope.FindVar(Input("X")); + if (x == nullptr) { + return; + } + auto *out = scope.FindVar(Output("Out")); + PADDLE_ENFORCE( + out != nullptr, + "The Output(Out) should not be null if the Input(X) is set."); + framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); + } +}; + +class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + AssignOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor, SelectedRows or LoDTensorArray) The input variable " + "could be LoDTensor, SelectedRows or LoDTensorArray.") + .AsDispensable(); + AddOutput("Out", + "(LoDTensor, SelectedRows or LoDTensorArray) The type of output " + "is the same as input X."); + AddComment(R"DOC(Assign Operator + +Out = X, when type in [LoDTensor/SelectedRows/LoDTensorArray] +raise error if the type is not listed above. +)DOC"); + } +}; + +class AssignInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + if (context->HasInput("X")) { + auto type = context->GetInputsVarType("X")[0]; + if (type == framework::VarDesc_VarType_SELECTED_ROWS || + type == framework::VarDesc_VarType_LOD_TENSOR) { + context->SetOutputDim("Out", context->GetInputDim("X")); + } + } + } +}; + +class AssignGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *op = new framework::OpDescBind(); + op->SetType("assign"); + op->SetInput("X", OutputGrad("Out")); + op->SetOutput("Out", InputGrad("X")); + return std::unique_ptr(op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker, + ops::AssignInferShape, ops::AssignOpProtoMaker); diff --git a/python/paddle/v2/framework/tests/test_assign_op.py b/python/paddle/v2/framework/tests/test_assign_op.py new file mode 100644 index 00000000000..1b0c145f1a6 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_assign_op.py @@ -0,0 +1,21 @@ +import op_test +import numpy +import unittest + + +class TestAssignOp(op_test.OpTest): + def setUp(self): + self.op_type = "assign" + x = numpy.random.random(size=(100, 10)) + self.inputs = {'X': x} + self.outputs = {'Out': x} + + def test_forward(self): + self.check_output() + + def test_backward(self): + self.check_grad(['X'], 'Out') + + +if __name__ == '__main__': + unittest.main() -- GitLab