Commit 6a5f6046 authored by minqiyang

Support stop_gradients var in imperative backward

test=develop
Parent 9e3155e0
@@ -69,6 +69,15 @@ inline std::string GradVarName(const std::string& var_name) {
   return result;
 }
 
+inline std::string OriginVarName(const std::string& grad_var_name) {
+  std::size_t pos = grad_var_name.find_last_of(kGradVarSuffix);
+  if (pos == std::string::npos) {
+    return grad_var_name;
+  } else {
+    return grad_var_name.substr(0, pos);
+  }
+}
+
 proto::VarType::Type GetDataTypeOfVar(const Variable* var);
 const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
 Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
......
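For orientation, a minimal standalone sketch of the naming round trip these two helpers implement, assuming the gradient suffix is "@GRAD" (as the test below also expects); kSuffix, grad_name, and origin_name are illustrative names, not framework APIs, and the sketch strips the whole suffix with rfind, whereas find_last_of in the patch searches for individual characters of the suffix:

    #include <iostream>
    #include <string>

    // Assumed value of kGradVarSuffix.
    const std::string kSuffix = "@GRAD";

    // Appends the gradient suffix, mirroring GradVarName.
    std::string grad_name(const std::string& var) { return var + kSuffix; }

    // Strips a trailing gradient suffix, mirroring the intent of OriginVarName.
    std::string origin_name(const std::string& grad_var) {
      std::size_t pos = grad_var.rfind(kSuffix);
      return pos == std::string::npos ? grad_var : grad_var.substr(0, pos);
    }

    int main() {
      std::cout << grad_name("X") << "\n";         // X@GRAD
      std::cout << origin_name("X@GRAD") << "\n";  // X
      return 0;
    }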
@@ -288,3 +288,12 @@ TEST(OpKernel, multi_inputs) {
   auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   op->Run(scope, cpu_place);
 }
+
+TEST(Functions, all) {
+  std::string var_name("X");
+  std::string grad_var_name = paddle::framework::GradVarName(var_name);
+  ASSERT_EQ(grad_var_name.c_str(), "X@GRAD");
+
+  std::string original_var_name =
+      paddle::framework::OriginVarName(grad_var_name);
+  ASSERT_EQ(original_var_name.c_str(), "X");
+}
@@ -21,6 +21,7 @@
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/string/printf.h"
 
 namespace paddle {
@@ -31,8 +32,9 @@ using framework::Variable;
 void AddTo(Variable* src, Variable* dst) {
   framework::LoDTensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
   framework::LoDTensor* src_tensor = src->GetMutable<framework::LoDTensor>();
-  PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(), "%lld vs %lld",
-                 dst_tensor->numel(), src_tensor->numel());
+  PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
+                 "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
+                 src_tensor->numel());
   float* dst_data = dst_tensor->mutable_data<float>(platform::CPUPlace());
   const float* src_data = src_tensor->data<float>();
   for (size_t i = 0; i < src_tensor->numel(); ++i) {
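Taken in isolation, the accumulation AddTo performs is a plain element-wise add of the source gradient buffer into the destination buffer, guarded by the size check. A minimal sketch with raw float vectors for illustration (add_to is an illustrative name, not the framework function):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Element-wise accumulation: dst[i] += src[i], the core of AddTo above.
    void add_to(const std::vector<float>& src, std::vector<float>* dst) {
      assert(src.size() == dst->size());  // mirrors the PADDLE_ENFORCE size check
      for (std::size_t i = 0; i < src.size(); ++i) {
        (*dst)[i] += src[i];
      }
    }

    int main() {
      std::vector<float> grad_accum{1.0f, 2.0f};  // gradient already accumulated
      std::vector<float> new_grad{0.5f, 0.5f};    // gradient from another op
      add_to(new_grad, &grad_accum);              // grad_accum becomes {1.5, 2.5}
      return 0;
    }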
@@ -114,7 +116,7 @@ framework::LoDTensor& VarBase::Grad() {
 std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
   if (!grad_op_desc_) {
-    VLOG(3) << "op with no grad: " << op_desc_->Type();
+    LOG(WARNING) << "op with no grad: " << op_desc_->Type();
     return {};
   }
   VLOG(3) << "op grad " << grad_op_desc_->Type();
@@ -124,20 +126,18 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
   for (auto it : grad_output_vars_) {
     auto& outputs = grad_outputs[it.first];
     for (size_t i = 0; i < it.second.size(); ++i) {
-      tmp_vars.emplace_back(new framework::Variable());
-      outputs.push_back(tmp_vars.back().get());
-      outputs.back()->GetMutable<framework::LoDTensor>();
+      // Allocate a new variable
+      Variable* tmp_var = new framework::Variable();
+      tmp_var->GetMutable<framework::LoDTensor>();
+      tmp_vars.emplace_back(tmp_var);
+      outputs.push_back(tmp_var);
     }
-    grad_invar_desc.SetShape(
-        framework::vectorize(var->Get<framework::LoDTensor>().dims()));
-    VLOG(3)
-        << "set op grad var desc's shape size "
-        << framework::vectorize(var->Get<framework::LoDTensor>().dims()).size();
   }
 
   framework::RuntimeContext ctx(grad_input_vars_, grad_outputs);
-  // No need to do static infer shape here.
+  // No need to do compile time infer shape here.
   // grad_op_desc_->InferShape(*block_);
   grad_op_desc_->InferVarType(block_);
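The rewritten loop hands the same raw pointer to both an owning container and the op's gradient-output list. A self-contained sketch of that pattern, assuming tmp_vars is a std::vector of std::unique_ptr as the original emplace_back/.get() usage suggests (the Variable struct below is a stand-in, not the framework type):

    #include <memory>
    #include <string>
    #include <vector>

    // Stand-in for framework::Variable, just for this sketch.
    struct Variable {
      std::string name;
    };

    int main() {
      // Owning storage keeps the temporaries alive for the duration of the call.
      std::vector<std::unique_ptr<Variable>> tmp_vars;
      // Non-owning view handed to the op as its gradient outputs.
      std::vector<Variable*> outputs;

      for (int i = 0; i < 3; ++i) {
        Variable* tmp_var = new Variable{"grad_out_" + std::to_string(i)};
        tmp_vars.emplace_back(tmp_var);  // unique_ptr takes ownership
        outputs.push_back(tmp_var);      // same pointer, no ownership
      }
      return 0;
    }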
@@ -156,11 +156,16 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
   for (auto it : grad_output_vars_) {
     auto& outputs = grad_outputs[it.first];
     auto& origin_outputs = it.second;
+    auto& forward_inputs = input_vars_[framework::OriginVarName(it.first)];
 
     for (size_t i = 0; i < outputs.size(); ++i) {
+      if (!forward_inputs[i]->stop_gradient_) {
         framework::Variable* orig_grad = origin_outputs[i];
         AddTo(outputs[i], orig_grad);
+      }
     }
   }
 
   return input_vars_;
 }
......
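This hunk is the heart of the change: when a gradient output is accumulated back into the corresponding forward input's gradient, inputs flagged with stop_gradient_ are skipped. A self-contained sketch of that gating logic (the Var struct and the accumulate name are illustrative, not framework types):

    #include <iostream>
    #include <vector>

    // Illustrative stand-in for a variable participating in autograd.
    struct Var {
      bool stop_gradient = false;  // mirrors VarBase::stop_gradient_
      float grad = 0.0f;           // gradient accumulated so far
    };

    // Accumulate freshly computed gradients into the forward inputs,
    // skipping any input that requested stop_gradient.
    void accumulate(const std::vector<float>& new_grads, std::vector<Var>* inputs) {
      for (std::size_t i = 0; i < new_grads.size(); ++i) {
        if (!(*inputs)[i].stop_gradient) {
          (*inputs)[i].grad += new_grads[i];
        }
      }
    }

    int main() {
      std::vector<Var> inputs(2);
      inputs[1].stop_gradient = true;  // e.g. a frozen parameter
      accumulate({0.3f, 0.7f}, &inputs);
      std::cout << inputs[0].grad << " " << inputs[1].grad << "\n";  // 0.3 0
      return 0;
    }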
@@ -57,7 +57,7 @@ class Tracer {
   void Trace(OpBase* op,
              const std::map<std::string, std::vector<VarBase*>>& inputs,
              const std::map<std::string, std::vector<VarBase*>>& outputs,
-             framework::BlockDesc* block) {
+             framework::BlockDesc* block, const bool stop_gradient) {
     std::map<std::string, VarBase*> vars;
 
     framework::OpDesc* op_desc = op->op_desc_;
......
@@ -152,7 +152,7 @@ PYBIND11_MODULE(core, m) {
       [](const imperative::VarBase &self) { return self.stop_gradient_; },
       [](imperative::VarBase &self, bool stop_gradient) {
         self.stop_gradient_ = stop_gradient;
-      })
+      });
 
   py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
       .def(py::init<>())
......
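The getter/setter lambda pair above is the standard pybind11 def_property pattern for exposing a member such as stop_gradient_ to Python. A minimal self-contained sketch of that pattern for a bool member (module, class, and property names here are illustrative, not the actual Paddle bindings):

    #include <pybind11/pybind11.h>

    namespace py = pybind11;

    // Illustrative stand-in for imperative::VarBase.
    struct VarBaseSketch {
      bool stop_gradient_ = false;
    };

    PYBIND11_MODULE(sketch_core, m) {
      py::class_<VarBaseSketch>(m, "VarBase")
          .def(py::init<>())
          .def_property(
              "stop_gradient",
              [](const VarBaseSketch& self) { return self.stop_gradient_; },
              [](VarBaseSketch& self, bool stop_gradient) {
                self.stop_gradient_ = stop_gradient;
              });
    }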