提交 e32e3068 编写于 作者: F fengjiayi

Develop backward building precess of single op

上级 a2dc9614
......@@ -12,8 +12,9 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/framework/backward.h>
#include <paddle/framework/net.h>
#include "paddle/framework/backward.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
......@@ -71,6 +72,24 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
//! TODO(dzh)
} else {
//! TODO(fjy)
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
for (std::string& grad_input : grad_op->inputs_) {
if (no_grad_names.count(grad_input)) {
std::string prefix = grad_input.substr(
0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();
std::vector<std::string> fill_zeros_in = {prefix};
std::vector<std::string> fill_zeros_out = {grad_input};
net.AddOp(OpRegistry::CreateOp("fill_zeros_like", fill_zeros_in,
fill_zeros_out, AttributeMap()));
}
}
for (std::string& grad_output : grad_op->output_) {
if (no_grad_names.count(grad_output)) {
grad_output = OperatorBase::EMPTY_VAR_NAME();
}
}
net.AddOp(grad_op);
}
net->CompleteAddOp();
......
......@@ -67,6 +67,9 @@ class OperatorBase {
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
/// Variables with this suffix are supposed to be filled up with zeros.
static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; }
virtual ~OperatorBase() {}
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册