提交 6a5f6046 编写于 作者: M minqiyang

Support stop_gradients var in imperative backward

test=develop
上级 9e3155e0
......@@ -69,6 +69,15 @@ inline std::string GradVarName(const std::string& var_name) {
return result;
}
inline std::string OriginVarName(const std::string& grad_var_name) {
std::size_t pos = grad_var_name.find_last_of(kGradVarSuffix);
if (pos == std::string::npos) {
return grad_var_name;
} else {
return grad_var_name.substr(0, pos);
}
}
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
......
......@@ -288,3 +288,12 @@ TEST(OpKernel, multi_inputs) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_place);
}
TEST(Functions, all) {
std::string var_name("X");
std::string grad_var_name = paddle::framework::GradVarName(var_name);
ASSERT_EQ(grad_var_name.c_str(), "X@GRAD");
std::string original_var_name =
paddle::framework::OriginVarName(grad_var_name);
ASSERT_EQ(original_var_name.c_str(), "X");
}
......@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
......@@ -31,8 +32,9 @@ using framework::Variable;
void AddTo(Variable* src, Variable* dst) {
framework::LoDTensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
framework::LoDTensor* src_tensor = src->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(), "%lld vs %lld",
dst_tensor->numel(), src_tensor->numel());
PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
"dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
src_tensor->numel());
float* dst_data = dst_tensor->mutable_data<float>(platform::CPUPlace());
const float* src_data = src_tensor->data<float>();
for (size_t i = 0; i < src_tensor->numel(); ++i) {
......@@ -114,7 +116,7 @@ framework::LoDTensor& VarBase::Grad() {
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
if (!grad_op_desc_) {
VLOG(3) << "op with no grad: " << op_desc_->Type();
LOG(WARNING) << "op with no grad: " << op_desc_->Type();
return {};
}
VLOG(3) << "op grad " << grad_op_desc_->Type();
......@@ -124,20 +126,18 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
for (auto it : grad_output_vars_) {
auto& outputs = grad_outputs[it.first];
for (size_t i = 0; i < it.second.size(); ++i) {
tmp_vars.emplace_back(new framework::Variable());
outputs.push_back(tmp_vars.back().get());
outputs.back()->GetMutable<framework::LoDTensor>();
// Allocate a new variable
Variable* tmp_var = new framework::Variable();
tmp_var->GetMutable<framework::LoDTensor>();
tmp_vars.emplace_back(tmp_var);
outputs.push_back(tmp_var);
}
grad_invar_desc.SetShape(
framework::vectorize(var->Get<framework::LoDTensor>().dims()));
VLOG(3)
<< "set op grad var desc's shape size "
<< framework::vectorize(var->Get<framework::LoDTensor>().dims()).size();
}
framework::RuntimeContext ctx(grad_input_vars_, grad_outputs);
// No need to do static infer shape here.
// No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_);
grad_op_desc_->InferVarType(block_);
......@@ -156,11 +156,16 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
for (auto it : grad_output_vars_) {
auto& outputs = grad_outputs[it.first];
auto& origin_outputs = it.second;
auto& forward_inputs = input_vars_[framework::OriginVarName(it.first)];
for (size_t i = 0; i < outputs.size(); ++i) {
if (!forward_inputs[i]->stop_gradient_) {
framework::Variable* orig_grad = origin_outputs[i];
AddTo(outputs[i], orig_grad);
}
}
}
return input_vars_;
}
......
......@@ -57,7 +57,7 @@ class Tracer {
void Trace(OpBase* op,
const std::map<std::string, std::vector<VarBase*>>& inputs,
const std::map<std::string, std::vector<VarBase*>>& outputs,
framework::BlockDesc* block) {
framework::BlockDesc* block, const bool stop_gradient) {
std::map<std::string, VarBase*> vars;
framework::OpDesc* op_desc = op->op_desc_;
......
......@@ -152,7 +152,7 @@ PYBIND11_MODULE(core, m) {
[](const imperative::VarBase &self) { return self.stop_gradient_; },
[](imperative::VarBase &self, bool stop_gradient) {
self.stop_gradient_ = stop_gradient;
})
});
py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
.def(py::init<>())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册