diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index a29a81c994c8168afe4b2088d0f7cc6ad3807c5b..26d93336b131e1d2bf6dc3188086774b5abc4f0b 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -31,3 +31,5 @@ add_dependencies(framework_py_proto framework_py_proto_init)
 
 cc_library(net SRCS net.cc DEPS op_registry)
 cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
+
+cc_library(backward SRCS backward.cc DEPS net)
diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
new file mode 100644
index 0000000000000000000000000000000000000000..11690342185ba6046df2b0f9121bb9bd8cd9a073
--- /dev/null
+++ b/paddle/framework/backward.cc
@@ -0,0 +1,93 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include <paddle/framework/backward.h>
+#include <paddle/framework/net.h>
+
+namespace paddle {
+namespace framework {
+
+static bool AllInSet(const std::vector<std::string>& names,
+                     const std::string& suffix,
+                     const std::unordered_set<std::string>& set) {
+  for (auto& name : names) {
+    if (set.find(name + suffix) == set.end()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+static std::vector<size_t> InSetIdx(
+    const std::vector<std::string>& names, const std::string& suffix,
+    const std::unordered_set<std::string>& set) {
+  std::vector<size_t> ret_val;
+  ret_val.reserve(names.size());
+  for (size_t i = 0; i < names.size(); ++i) {
+    if (set.find(names[i] + suffix) != set.end()) {
+      ret_val.push_back(i);
+    }
+  }
+  return ret_val;
+}
+
+static std::shared_ptr<OperatorBase> EmptyOp() {
+  auto net_op = std::make_shared<NetOp>();
+  net_op->CompleteAddOp();
+  return net_op;
+}
+
+static std::shared_ptr<OperatorBase> BackwardImpl(
+    const OperatorBase& forwardOp,
+    std::unordered_set<std::string>& no_grad_names, int& uniq_id) {
+  if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
+               no_grad_names)) {
+    return EmptyOp();
+  }
+
+  if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
+               no_grad_names)) {
+    for (auto& name : forwardOp.inputs_) {
+      // Mark all inputs as not needing a gradient.
+      no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
+    }
+    return EmptyOp();
+  }
+
+  auto* net = new NetOp();
+
+  if (forwardOp.IsNetOp()) {
+    //! TODO(dzh)
+  } else {
+    //! TODO(fjy)
+  }
+
+  net->CompleteAddOp();
+  return std::shared_ptr<OperatorBase>(net);
+}
+
+extern std::shared_ptr<OperatorBase> Backward(
+    const std::shared_ptr<OperatorBase>& forwardOp,
+    const std::unordered_set<std::string>& no_grad_vars) {
+  std::unordered_set<std::string> no_grad_names;
+  no_grad_names.reserve(no_grad_vars.size());
+
+  for (auto& name : no_grad_vars) {
+    no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
+  }
+  int uid = 0;
+  return BackwardImpl(*forwardOp, no_grad_names, uid);
+}
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h
new file mode 100644
index 0000000000000000000000000000000000000000..e835ef6351102686d68d657fdc5a6a2913ace3e6
--- /dev/null
+++ b/paddle/framework/backward.h
@@ -0,0 +1,31 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include <unordered_set>
+#include "operator.h"
+namespace paddle {
+namespace framework {
+
+/**
+ * @brief Create the backward operator for a given forward operator.
+ * @param forwardOp the forward operator to differentiate
+ * @param no_grad_vars names of forward inputs whose gradients are not needed
+ * @return the backward operator; an empty NetOp when no gradient is required
+ */
+extern std::shared_ptr<OperatorBase> Backward(
+    const std::shared_ptr<OperatorBase>& forwardOp,
+    const std::unordered_set<std::string>& no_grad_vars);
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index f16deae028d76dc40d6bc589648b461c430c3c98..5bcd7ac927952f91c8b292f061b50a7ae16dbf3d 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -302,6 +302,8 @@ class OpRegistry {
 
   static std::shared_ptr<OperatorBase> CreateGradOp(
       std::shared_ptr<OperatorBase> op) {
+    PADDLE_ENFORCE(!op->IsNetOp(),
+                   "Use framework::Backward to get backward ops");
     GradOpBuilder builder(op.get());
     std::shared_ptr<OperatorBase> grad_op(builder.Build());
     grad_op->Init();
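
Note on the pruning rule: the TODO(dzh)/TODO(fjy) branches are not filled in yet, but the part of backward.cc that already works is the pruning check, where an op whose relevant names all appear in no_grad_names with the gradient suffix appended collapses to an empty op. The standalone sketch below illustrates that rule with standard C++ only; AllInSet is copied from the patch, while the "@GRAD" literal (standing in for OperatorBase::GRAD_VAR_SUFFIX()), the variable names "W"/"b", and the main() driver are illustrative assumptions, not part of the patch.

// Standalone sketch of the pruning rule in backward.cc: when every input of a
// forward op, with the gradient suffix appended, is already in no_grad_names,
// the backward op degenerates to EmptyOp(). "@GRAD" and "W"/"b" are assumed.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Same helper as in backward.cc.
static bool AllInSet(const std::vector<std::string>& names,
                     const std::string& suffix,
                     const std::unordered_set<std::string>& set) {
  for (auto& name : names) {
    if (set.find(name + suffix) == set.end()) {
      return false;
    }
  }
  return true;
}

int main() {
  const std::string kGradSuffix = "@GRAD";  // stand-in for GRAD_VAR_SUFFIX()
  const std::vector<std::string> fwd_inputs = {"W", "b"};

  // The caller asks for no gradient of "W" and "b"; Backward() turns those
  // names into suffixed entries, exactly as the patch does.
  std::unordered_set<std::string> no_grad_names;
  for (const auto& name : {std::string("W"), std::string("b")}) {
    no_grad_names.insert(name + kGradSuffix);
  }

  // All inputs are in the set, so BackwardImpl would return EmptyOp() here.
  std::cout << std::boolalpha
            << AllInSet(fwd_inputs, kGradSuffix, no_grad_names) << std::endl;
  // prints: true
  return 0;
}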