diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index aace711c10bf7947fa759f5b1e9a98ebfe1cdf18..cef8ab1a88bb78343fb2683c8beb8d3420d2bd69 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -330,10 +330,6 @@ std::vector OpTranscriber::GenerateOperationInput( std::set yaml_input_set; for (const auto& info : input_infos) { - if (auto special_handler = this->GetSpecialInputHandlers(info.name)) { - continue; - } - std::string legacy_input_name = op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); @@ -381,7 +377,6 @@ std::vector OpTranscriber::GenerateOperationInput( std::vector legacy_input_vars; // return empty OpResult if this arg is optional and not shown in OpDesc - // TODO(lyk): HasInput doesnot consider variadic attribute if (op_desc.HasInput(legacy_input_name, true)) { legacy_input_vars = op_desc.Input(legacy_input_name, true); } @@ -436,6 +431,10 @@ std::vector OpTranscriber::GenerateOperationInput( // if src type is Tensor if (!is_vector) { + IR_ENFORCE(legacy_input_vars.size() == 1u, + "Input %s not found when parsing op %s", + info.name, + op_desc.Type()); auto defining_info = (*param_map)[legacy_input_vars[0]]; op_inputs.push_back(defining_info.value); diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 14b904c551ce6513447769b89a509c345052649f..4e9775f7e9f8335ae2709720974524041379085b 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -24,6 +24,7 @@ #include "paddle/fluid/ir_adaptor/translator/type_translator.h" #include "paddle/ir/core/attribute.h" #include "paddle/ir/core/block.h" +#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/enforce.h" 
@@ -38,17 +39,19 @@ using ProgramDesc = ::paddle::framework::ProgramDesc; using BlockDesc = ::paddle::framework::BlockDesc; using VarDesc = ::paddle::framework::VarDesc; +const std::unordered_set ProgramTranslator::no_cast_var_names = { + "feed", + "fetch", +}; + +/* Key of the operation attribute that SetStopGradientAttributeForAllValue reads and writes. When the attribute is created there it is an ir::ArrayAttribute holding one ir::BoolAttribute per result of the op. */ constexpr char kAttrStopGradients[] = "stop_gradient"; + ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, ir::Program* program) : legacy_program_(legacy_program), program_(program) { ctx_ = ir::IrContext::Instance(); } -const std::unordered_set ProgramTranslator::no_cast_var_names = { - "feed", - "fetch", -}; - void ProgramTranslator::Translate() { PADDLE_ENFORCE_EQ( legacy_program_->Size(), @@ -71,6 +74,11 @@ void ProgramTranslator::Translate() { const BlockDesc& block = legacy_program_->Block(block_idx); SetParameterFromSingleBlock(block); } + + /* Additional pass over every block of the legacy program: record the stop-gradient flags of its variables on the translated operations. */ for (size_t block_idx = 0; block_idx < legacy_program_->Size(); block_idx++) { + const BlockDesc& block = legacy_program_->Block(block_idx); + SetStopGradientAttributeForAllValue(block); + } } inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx, @@ -198,5 +206,35 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { } } +/* Copies the stop-gradient flag of legacy variables onto the operations that define their translated values. For every (var_name, value_info) entry of param_map_ the variable is looked up with block.FindVarRecursive and skipped when it is not found; otherwise the op returned by value.owner() gets its kAttrStopGradients array updated so that the slot at value.GetResultIndex() holds var->StopGradient(). Slots of the other results keep their previous value, or false when the op had no such attribute yet. NOTE(review): the whole param_map_ is walked on every call and Translate() calls this once per block, so the same op may be rewritten several times; confirm this is intended. NOTE(review): value.owner() is dereferenced without a null check; presumably every value in param_map_ has a defining op, verify. NOTE(review): std::vector and dyn_cast() appear here without template arguments, presumably lost when this patch text was captured; confirm against the repository. */ void ProgramTranslator::SetStopGradientAttributeForAllValue( + const BlockDesc& block) { + // Currently we set stop gradient for operation that generated a value + // connected with VarDesc + for (const auto& [var_name, value_info] : param_map_) { + VLOG(10) << "[op translated][stop gradient]" << var_name; + VarDesc* var = block.FindVarRecursive(var_name); + /* No VarDesc found for this name from the given block: leave the defining op untouched. */ if (var == nullptr) { + continue; + } + ir::OpResult value = value_info.value; + auto* defining_op = value.owner(); + VLOG(8) << "[op translated][stop gradient]" << var_name + << " from: " << defining_op->name(); + /* Start from the flags already stored on the op, or from an all-false array sized by num_results(). */ std::vector stop_gradients; + if (defining_op->HasAttribute(kAttrStopGradients)) { + stop_gradients = defining_op->attribute(kAttrStopGradients) + .dyn_cast() + .data(); + } else { + stop_gradients = std::vector( + 
defining_op->num_results(), ir::BoolAttribute::get(ctx_, false)); + } + /* Overwrite only the slot of the result bound to var_name, then store the whole array back on the op. */ stop_gradients[value.GetResultIndex()] = + ir::BoolAttribute::get(ctx_, var->StopGradient()); + defining_op->set_attribute(kAttrStopGradients, + ir::ArrayAttribute::get(ctx_, stop_gradients)); + } +} + } // namespace translator } // namespace paddle diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index ec479aad730b195155aeb946c1f8e20883b2e22d..ce34fec141912e6acc60afb020c7418bb24ffab6 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -72,12 +72,13 @@ class ProgramTranslator { /// 2. "fetch", the output variable of fetch op /// However, new feed has no input and new fetch has no output /// So we don't handle these two vairables when - /// `ExtractParameterFromSingleBlock` + /// `Get/SetParameterFromSingleBlock` static const std::unordered_set no_cast_var_names; void GetParameterForSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block); void SetParameterFromSingleBlock(const BlockDesc& block); + /* Declaration of the stop-gradient pass that takes one legacy block; defined in program_translator.cc. */ void SetStopGradientAttributeForAllValue(const BlockDesc& block); }; } // namespace translator diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 5348ef81ef962d87385eca5c2aaea3085aaef898..3600f9a55dd10f4863f5866cd1a3d0523fa78ecd 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -205,6 +205,11 @@ std::string Operation::name() const { return p_name ? p_name : ""; } +/* Returns, by value, the attribute stored under key in attributes_. IR_ENFORCE is applied to HasAttribute(key) first, so a missing key is reported with the op name and the key before attributes_.at(key) is reached. */ Attribute Operation::attribute(const std::string &key) const { + IR_ENFORCE(HasAttribute(key), "operation(%s): no attribute %s", name(), key); + return attributes_.at(key); +} + Region *Operation::GetParentRegion() const { return parent_ ? 
parent_->GetParent() : nullptr; } diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index 654674869b88b99c83c6afc0f01b7f08c80b0b2c..711434220ef75b5cff982108b3ea939e9611de28 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -64,10 +64,16 @@ class IR_API alignas(8) Operation final { const AttributeMap &attributes() const { return attributes_; } - void SetAttribute(const std::string &key, Attribute value) { + /* Stores value under key in attributes_, replacing whatever was stored under that key before. Renamed from SetAttribute by this patch. */ void set_attribute(const std::string &key, Attribute value) { attributes_[key] = value; } + /* Checked lookup of the attribute stored under key; defined in operation.cc. */ Attribute attribute(const std::string &key) const; + + /* True when attributes_ contains an entry for key. */ bool HasAttribute(const std::string &key) const { + return attributes_.find(key) != attributes_.end(); + } + ir::OpInfo info() const { return info_; } uint32_t num_results() const { return num_results_; } diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index 0e246af03cbe10360b3cbe568ee19658f3282ae0..c7f9c5e8af2593d4c9bda1755b753fe554f4e014 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -274,6 +274,6 @@ TEST(op_test, module_op_death) { EXPECT_EQ(program.module_op().program(), &program); EXPECT_EQ(program.module_op().ir_context(), ctx); - program.module_op()->SetAttribute("program", - ir::PointerAttribute::get(ctx, &program)); + program.module_op()->set_attribute("program", + ir::PointerAttribute::get(ctx, &program)); }