未验证 提交 c5a191bb 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] add stop_gradient attribute for defining op (#55235)

* add stop_gradient attribute for defining op

* modify by reviews

* fix
上级 4905a247
...@@ -330,10 +330,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput( ...@@ -330,10 +330,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
std::set<std::string> yaml_input_set; std::set<std::string> yaml_input_set;
for (const auto& info : input_infos) { for (const auto& info : input_infos) {
if (auto special_handler = this->GetSpecialInputHandlers(info.name)) {
continue;
}
std::string legacy_input_name = std::string legacy_input_name =
op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
...@@ -381,7 +377,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput( ...@@ -381,7 +377,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
std::vector<std::string> legacy_input_vars; std::vector<std::string> legacy_input_vars;
// return empty OpResult if this arg is optional and not shown in OpDesc // return empty OpResult if this arg is optional and not shown in OpDesc
// TODO(lyk): HasInput doesnot consider variadic attribute
if (op_desc.HasInput(legacy_input_name, true)) { if (op_desc.HasInput(legacy_input_name, true)) {
legacy_input_vars = op_desc.Input(legacy_input_name, true); legacy_input_vars = op_desc.Input(legacy_input_name, true);
} }
...@@ -436,6 +431,10 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput( ...@@ -436,6 +431,10 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
// if src type is Tensor // if src type is Tensor
if (!is_vector) { if (!is_vector) {
IR_ENFORCE(legacy_input_vars.size() == 1u,
"Input %s not found when parsing op %s",
info.name,
op_desc.Type());
auto defining_info = (*param_map)[legacy_input_vars[0]]; auto defining_info = (*param_map)[legacy_input_vars[0]];
op_inputs.push_back(defining_info.value); op_inputs.push_back(defining_info.value);
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "paddle/fluid/ir_adaptor/translator/type_translator.h" #include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/ir/core/attribute.h" #include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/block.h" #include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h" #include "paddle/ir/core/enforce.h"
...@@ -38,17 +39,19 @@ using ProgramDesc = ::paddle::framework::ProgramDesc; ...@@ -38,17 +39,19 @@ using ProgramDesc = ::paddle::framework::ProgramDesc;
using BlockDesc = ::paddle::framework::BlockDesc; using BlockDesc = ::paddle::framework::BlockDesc;
using VarDesc = ::paddle::framework::VarDesc; using VarDesc = ::paddle::framework::VarDesc;
const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {
"feed",
"fetch",
};
constexpr char kAttrStopGradients[] = "stop_gradient";
ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
ir::Program* program) ir::Program* program)
: legacy_program_(legacy_program), program_(program) { : legacy_program_(legacy_program), program_(program) {
ctx_ = ir::IrContext::Instance(); ctx_ = ir::IrContext::Instance();
} }
const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {
"feed",
"fetch",
};
void ProgramTranslator::Translate() { void ProgramTranslator::Translate() {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
legacy_program_->Size(), legacy_program_->Size(),
...@@ -71,6 +74,11 @@ void ProgramTranslator::Translate() { ...@@ -71,6 +74,11 @@ void ProgramTranslator::Translate() {
const BlockDesc& block = legacy_program_->Block(block_idx); const BlockDesc& block = legacy_program_->Block(block_idx);
SetParameterFromSingleBlock(block); SetParameterFromSingleBlock(block);
} }
for (size_t block_idx = 0; block_idx < legacy_program_->Size(); block_idx++) {
const BlockDesc& block = legacy_program_->Block(block_idx);
SetStopGradientAttributeForAllValue(block);
}
} }
inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx, inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx,
...@@ -198,5 +206,35 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { ...@@ -198,5 +206,35 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
} }
} }
void ProgramTranslator::SetStopGradientAttributeForAllValue(
const BlockDesc& block) {
// Currently we set stop gradient for operation that generated a value
// connected with VarDesc
for (const auto& [var_name, value_info] : param_map_) {
VLOG(10) << "[op translated][stop gradient]" << var_name;
VarDesc* var = block.FindVarRecursive(var_name);
if (var == nullptr) {
continue;
}
ir::OpResult value = value_info.value;
auto* defining_op = value.owner();
VLOG(8) << "[op translated][stop gradient]" << var_name
<< " from: " << defining_op->name();
std::vector<ir::Attribute> stop_gradients;
if (defining_op->HasAttribute(kAttrStopGradients)) {
stop_gradients = defining_op->attribute(kAttrStopGradients)
.dyn_cast<ir::ArrayAttribute>()
.data();
} else {
stop_gradients = std::vector<ir::Attribute>(
defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
}
stop_gradients[value.GetResultIndex()] =
ir::BoolAttribute::get(ctx_, var->StopGradient());
defining_op->set_attribute(kAttrStopGradients,
ir::ArrayAttribute::get(ctx_, stop_gradients));
}
}
} // namespace translator } // namespace translator
} // namespace paddle } // namespace paddle
...@@ -72,12 +72,13 @@ class ProgramTranslator { ...@@ -72,12 +72,13 @@ class ProgramTranslator {
/// 2. "fetch", the output variable of fetch op /// 2. "fetch", the output variable of fetch op
/// However, new feed has no input and new fetch has no output /// However, new feed has no input and new fetch has no output
/// So we don't handle these two variables when /// So we don't handle these two variables when
/// `ExtractParameterFromSingleBlock` /// `Get/SetParameterFromSingleBlock`
static const std::unordered_set<std::string> no_cast_var_names; static const std::unordered_set<std::string> no_cast_var_names;
void GetParameterForSingleBlock(const BlockDesc& block); void GetParameterForSingleBlock(const BlockDesc& block);
void InsertOperationToSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block);
void SetParameterFromSingleBlock(const BlockDesc& block); void SetParameterFromSingleBlock(const BlockDesc& block);
void SetStopGradientAttributeForAllValue(const BlockDesc& block);
}; };
} // namespace translator } // namespace translator
......
...@@ -205,6 +205,11 @@ std::string Operation::name() const { ...@@ -205,6 +205,11 @@ std::string Operation::name() const {
return p_name ? p_name : ""; return p_name ? p_name : "";
} }
/// Returns the attribute stored under `key`.
///
/// Presence of `key` is checked with IR_ENFORCE first, so a missing key is
/// reported with the operation name and the requested key.
///
/// @param key attribute name to look up.
/// @return the attribute mapped to `key`.
Attribute Operation::attribute(const std::string &key) const {
  IR_ENFORCE(HasAttribute(key), "operation(%s): no attribute %s", name(), key);
  // The key is known to exist at this point, so the iterator is valid.
  const auto entry = attributes_.find(key);
  return entry->second;
}
Region *Operation::GetParentRegion() const { Region *Operation::GetParentRegion() const {
return parent_ ? parent_->GetParent() : nullptr; return parent_ ? parent_->GetParent() : nullptr;
} }
......
...@@ -64,10 +64,16 @@ class IR_API alignas(8) Operation final { ...@@ -64,10 +64,16 @@ class IR_API alignas(8) Operation final {
const AttributeMap &attributes() const { return attributes_; } const AttributeMap &attributes() const { return attributes_; }
void SetAttribute(const std::string &key, Attribute value) { void set_attribute(const std::string &key, Attribute value) {
attributes_[key] = value; attributes_[key] = value;
} }
Attribute attribute(const std::string &key) const;
/// Reports whether an attribute is stored under `key`.
bool HasAttribute(const std::string &key) const {
  return attributes_.count(key) != 0;
}
ir::OpInfo info() const { return info_; } ir::OpInfo info() const { return info_; }
uint32_t num_results() const { return num_results_; } uint32_t num_results() const { return num_results_; }
......
...@@ -274,6 +274,6 @@ TEST(op_test, module_op_death) { ...@@ -274,6 +274,6 @@ TEST(op_test, module_op_death) {
EXPECT_EQ(program.module_op().program(), &program); EXPECT_EQ(program.module_op().program(), &program);
EXPECT_EQ(program.module_op().ir_context(), ctx); EXPECT_EQ(program.module_op().ir_context(), ctx);
program.module_op()->SetAttribute("program", program.module_op()->set_attribute("program",
ir::PointerAttribute::get(ctx, &program)); ir::PointerAttribute::get(ctx, &program));
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册