Unverified commit c5a191bb authored by: K kangguangli, committed by: GitHub

[NewIR] add stop_gradient attribute for defining op (#55235)

* add stop_gradient attribute for defining op

* modify by reviews

* fix
Parent 4905a247
......@@ -330,10 +330,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
std::set<std::string> yaml_input_set;
for (const auto& info : input_infos) {
if (auto special_handler = this->GetSpecialInputHandlers(info.name)) {
continue;
}
std::string legacy_input_name =
op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
......@@ -381,7 +377,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
std::vector<std::string> legacy_input_vars;
// return empty OpResult if this arg is optional and not shown in OpDesc
// TODO(lyk): HasInput doesnot consider variadic attribute
if (op_desc.HasInput(legacy_input_name, true)) {
legacy_input_vars = op_desc.Input(legacy_input_name, true);
}
......@@ -436,6 +431,10 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
// if src type is Tensor
if (!is_vector) {
IR_ENFORCE(legacy_input_vars.size() == 1u,
"Input %s not found when parsing op %s",
info.name,
op_desc.Type());
auto defining_info = (*param_map)[legacy_input_vars[0]];
op_inputs.push_back(defining_info.value);
......
......@@ -24,6 +24,7 @@
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h"
......@@ -38,17 +39,19 @@ using ProgramDesc = ::paddle::framework::ProgramDesc;
using BlockDesc = ::paddle::framework::BlockDesc;
using VarDesc = ::paddle::framework::VarDesc;
const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {
"feed",
"fetch",
};
constexpr char kAttrStopGradients[] = "stop_gradient";
// Constructs a translator that reads from `legacy_program` and emits the
// translated ops/values into `program`. Neither pointer is owned; both must
// outlive this translator.
ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
ir::Program* program)
: legacy_program_(legacy_program), program_(program) {
// Cache the process-wide IR context; used later to build attributes/types.
ctx_ = ir::IrContext::Instance();
}
const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {
"feed",
"fetch",
};
void ProgramTranslator::Translate() {
PADDLE_ENFORCE_EQ(
legacy_program_->Size(),
......@@ -71,6 +74,11 @@ void ProgramTranslator::Translate() {
const BlockDesc& block = legacy_program_->Block(block_idx);
SetParameterFromSingleBlock(block);
}
for (size_t block_idx = 0; block_idx < legacy_program_->Size(); block_idx++) {
const BlockDesc& block = legacy_program_->Block(block_idx);
SetStopGradientAttributeForAllValue(block);
}
}
inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx,
......@@ -198,5 +206,35 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
}
}
void ProgramTranslator::SetStopGradientAttributeForAllValue(
const BlockDesc& block) {
// Currently we set stop gradient for operation that generated a value
// connected with VarDesc
for (const auto& [var_name, value_info] : param_map_) {
VLOG(10) << "[op translated][stop gradient]" << var_name;
VarDesc* var = block.FindVarRecursive(var_name);
if (var == nullptr) {
continue;
}
ir::OpResult value = value_info.value;
auto* defining_op = value.owner();
VLOG(8) << "[op translated][stop gradient]" << var_name
<< " from: " << defining_op->name();
std::vector<ir::Attribute> stop_gradients;
if (defining_op->HasAttribute(kAttrStopGradients)) {
stop_gradients = defining_op->attribute(kAttrStopGradients)
.dyn_cast<ir::ArrayAttribute>()
.data();
} else {
stop_gradients = std::vector<ir::Attribute>(
defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
}
stop_gradients[value.GetResultIndex()] =
ir::BoolAttribute::get(ctx_, var->StopGradient());
defining_op->set_attribute(kAttrStopGradients,
ir::ArrayAttribute::get(ctx_, stop_gradients));
}
}
} // namespace translator
} // namespace paddle
......@@ -72,12 +72,13 @@ class ProgramTranslator {
/// 2. "fetch", the output variable of fetch op
/// However, new feed has no input and new fetch has no output
/// So we don't handle these two vairables when
/// `ExtractParameterFromSingleBlock`
/// `Get/SetParameterFromSingleBlock`
static const std::unordered_set<std::string> no_cast_var_names;
void GetParameterForSingleBlock(const BlockDesc& block);
void InsertOperationToSingleBlock(const BlockDesc& block);
void SetParameterFromSingleBlock(const BlockDesc& block);
void SetStopGradientAttributeForAllValue(const BlockDesc& block);
};
} // namespace translator
......
......@@ -205,6 +205,11 @@ std::string Operation::name() const {
return p_name ? p_name : "";
}
// Returns the attribute stored under `key`.
// Aborts via IR_ENFORCE if the operation has no such attribute.
Attribute Operation::attribute(const std::string &key) const {
  // Single map lookup instead of HasAttribute(key) + attributes_.at(key),
  // which traversed the map twice.
  auto it = attributes_.find(key);
  IR_ENFORCE(it != attributes_.end(),
             "operation(%s): no attribute %s",
             name(),
             key);
  return it->second;
}
// Returns the region containing this operation, or nullptr for a detached op
// (one that has not been inserted into any block yet).
Region *Operation::GetParentRegion() const {
  if (parent_ == nullptr) {
    return nullptr;
  }
  return parent_->GetParent();
}
......
......@@ -64,10 +64,16 @@ class IR_API alignas(8) Operation final {
const AttributeMap &attributes() const { return attributes_; }
void SetAttribute(const std::string &key, Attribute value) {
// Stores `value` under `key`, replacing any existing attribute of that name.
void set_attribute(const std::string &key, Attribute value) {
  attributes_.insert_or_assign(key, value);
}
Attribute attribute(const std::string &key) const;
// True iff an attribute named `key` is present on this operation.
bool HasAttribute(const std::string &key) const {
  // Unique-key map: count() is 0 or 1, i.e. exactly a membership test.
  return attributes_.count(key) != 0;
}
ir::OpInfo info() const { return info_; }
uint32_t num_results() const { return num_results_; }
......
......@@ -274,6 +274,6 @@ TEST(op_test, module_op_death) {
EXPECT_EQ(program.module_op().program(), &program);
EXPECT_EQ(program.module_op().ir_context(), ctx);
program.module_op()->SetAttribute("program",
ir::PointerAttribute::get(ctx, &program));
program.module_op()->set_attribute("program",
ir::PointerAttribute::get(ctx, &program));
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To post a comment, please register