Unverified · Commit 7470889f authored by kangguangli, committed by GitHub

[NewIR]support attribute hook and fix_reduce_all (#55553)

* support attribute hook and fix_reduce_all

* resolve merge conflicts

* fix coverage ci

* trigger CI

* trigger CI

* fix coverage ci
Parent: 84a56b4a
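
For context before the diff: a minimal sketch of the extension point this commit adds (not code from the commit itself). `OpTranscriber` and `AttributeHandlerFn` are the names from the diff below; `ClipOpTranscriber` and the `min` attribute are hypothetical, used only to show how a special attribute handler would be registered:

```cpp
// Hypothetical use of the new attribute hook (sketch only).
// An OpTranscriber subclass returns an AttributeHandlerFn for attributes it
// wants to translate itself; returning nullptr keeps the default
// name-mapping path in TranslateOpAttribute.
struct ClipOpTranscriber : public OpTranscriber {
  AttributeHandlerFn GetSpecialAttributeHandlers(
      const std::string& attr_name) override {
    if (attr_name != "min") {
      return nullptr;
    }
    return [](ir::IrContext* ctx,
              const OpDesc& op_desc,
              const OpAttributeInfo& info) -> ir::Attribute {
      // Build the ir::Attribute directly from the legacy OpDesc; assumes
      // ir::FloatAttribute is available as a builtin attribute here.
      float min_v = paddle::get<float>(op_desc.GetAttr("min"));
      return ir::FloatAttribute::get(ctx, min_v);
    };
  }
};
```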
@@ -60,12 +60,14 @@ using OpAttributeInfo = dialect::OpAttributeInfo;
 using OpAttributeInfoList = std::vector<dialect::OpAttributeInfo>;
 using OpOutputInfo = dialect::OpOutputInfo;
 using OpOutputInfoList = std::vector<dialect::OpOutputInfo>;
-using InputHandleFn = std::function<ir::OpResult(ir::IrContext*,
-                                                 TranslationContext*,
-                                                 const OpDesc&,
-                                                 const std::string&,
-                                                 const OpInputInfo&,
-                                                 ir::Program*)>;
+using InputHandlerFn = std::function<ir::OpResult(ir::IrContext*,
+                                                  TranslationContext*,
+                                                  const OpDesc&,
+                                                  const std::string&,
+                                                  const OpInputInfo&,
+                                                  ir::Program*)>;
+using AttributeHandlerFn = std::function<ir::Attribute(
+    ir::IrContext*, const OpDesc&, const OpAttributeInfo&)>;
 
 constexpr char kTargetDialectPrefix[] = "pd.";
 constexpr char kEmptyVarName[] = "@EMPTY@";
@@ -291,7 +293,12 @@ struct OpTranscriber {
                                  const OpOutputMapping& arg_to_idx);
 
  public:
-  virtual InputHandleFn GetSpecialInputHandlers(std::string input_name) {
+  virtual InputHandlerFn GetSpecialInputHandlers(
+      const std::string& input_name) {
     return nullptr;
   }
+  virtual AttributeHandlerFn GetSpecialAttributeHandlers(
+      const std::string& input_name) {
+    return nullptr;
+  }
 };
@@ -558,6 +565,12 @@ ir::AttributeMap OpTranscriber::TranslateOpAttribute(
   ir::AttributeMap attribute_map = {};
   for (const auto& info : op_attr_infos) {
+    if (auto handler = this->GetSpecialAttributeHandlers(info.name)) {
+      auto new_attr = handler(ctx, op_desc, info);
+      attribute_map[info.name] = new_attr;
+      continue;
+    }
+
     auto legacy_attr_name =
         op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
     VLOG(10) << "[op: " << op_desc.Type()
@@ -885,7 +898,8 @@ ir::OpResult TranslateDropOutStateIn(ir::IrContext* ctx,
 // `rnn` has an additional input in dynamic graph
 struct RnnOpTranscriber : public OpTranscriber {
-  InputHandleFn GetSpecialInputHandlers(std::string input_name) override {
+  InputHandlerFn GetSpecialInputHandlers(
+      const std::string& input_name) override {
     if (input_name != "dropout_state_in") {
       return nullptr;
     }
@@ -1207,7 +1221,8 @@ ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx,
 }
 
 struct OneHotTranscriber : public OpTranscriber {
-  InputHandleFn GetSpecialInputHandlers(std::string input_name) override {
+  InputHandlerFn GetSpecialInputHandlers(
+      const std::string& input_name) override {
     if (input_name != "num_classes") {
       return nullptr;
     }
@@ -1215,21 +1230,53 @@ struct OneHotTranscriber : public OpTranscriber {
   };
 };
 
+ir::Attribute TranslateReduceAll(ir::IrContext* ctx,
+                                 const OpDesc& op_desc,
+                                 const OpAttributeInfo& attr_info) {
+  bool reduce_all = false;
+  if (op_desc.HasAttr("reduce_all")) {
+    reduce_all = paddle::get<bool>(op_desc.GetAttr("reduce_all"));
+  }
+
+  if (reduce_all) {
+    return ir::ArrayAttribute::get(ctx, std::vector<ir::Attribute>{});
+  }
+
+  auto& attribute_translator = AttributeTranslator::instance();
+  auto& op_normalizer = OpNameNormalizer::instance();
+  auto legacy_attr_name =
+      op_normalizer.GetLegacyAttrName(op_desc.Type(), attr_info.name);
+  paddle::framework::Attribute dims = op_desc.GetAttr(legacy_attr_name);
+  return attribute_translator(attr_info.type_name, dims);
+}
+
+struct ReduceOpTranscriber : public OpTranscriber {
+  AttributeHandlerFn GetSpecialAttributeHandlers(
+      const std::string& input_name) override {
+    if (input_name != "axis") {
+      return nullptr;
+    }
+    return TranslateReduceAll;
+  }
+};
+
 OpTranslator::OpTranslator() {
   general_handler = OpTranscriber();
+  special_handlers["add_n"] = AddNOpTranscriber();
+  special_handlers["assign_value"] = AssignValueOpTranscriber();
+  special_handlers["cast"] = CastOpTranscriber();
   special_handlers["feed"] = FeedOpTranscriber();
   special_handlers["feed_with_place"] = FeedWithPlaceOpTranscriber();
   special_handlers["fetch_v2"] = FetchOpTranscriber();
-  special_handlers["cast"] = CastOpTranscriber();
-  special_handlers["split"] = SplitOpTranscriber();
+  special_handlers["increment"] = IncrementOpTranscriber();
   special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber();
   special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber();
-  special_handlers["assign_value"] = AssignValueOpTranscriber();
-  special_handlers["increment"] = IncrementOpTranscriber();
+  special_handlers["one_hot_v2"] = OneHotTranscriber();
+  special_handlers["reduce_all"] = ReduceOpTranscriber();
+  special_handlers["reduce_any"] = ReduceOpTranscriber();
   special_handlers["rnn"] = RnnOpTranscriber();
   special_handlers["shaddow_output"] = ShaddowOutputOpTranscriber();
-  special_handlers["one_hot_v2"] = OneHotTranscriber();
-  special_handlers["add_n"] = AddNOpTranscriber();
+  special_handlers["split"] = SplitOpTranscriber();
   special_handlers["sum"] = AddNOpTranscriber();
 }
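
A note on the `fix_reduce_all` half of this change: the legacy `reduce_all` flag has no direct NewIR counterpart, so `TranslateReduceAll` above encodes "reduce over all dimensions" as an empty `axis` array. A sketch of the resulting mapping (the `legacy_to_new_axis` helper is hypothetical):

```cpp
// Sketch of the (reduce_all, dim) -> axis mapping implemented by
// TranslateReduceAll; legacy_to_new_axis is a hypothetical helper.
ir::Attribute legacy_to_new_axis(ir::IrContext* ctx,
                                 bool reduce_all,
                                 std::vector<ir::Attribute> dims) {
  if (reduce_all) {
    // reduce_all=true overrides any explicit dim list:
    // an empty ArrayAttribute means "all axes".
    return ir::ArrayAttribute::get(ctx, std::vector<ir::Attribute>{});
  }
  return ir::ArrayAttribute::get(ctx, dims);
}
```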
@@ -160,5 +160,39 @@ class TestOneHotOpTranscriber(unittest.TestCase):
         _ = ir.translate_to_new_ir(main_program.desc)
 
 
+class TestReduceOpTranscriber(unittest.TestCase):
+    def test_reduce_all(self):
+        place = core.Place()
+        place.set_place(paddle.CPUPlace())
+        exe = paddle.static.Executor(place)
+
+        new_scope = paddle.static.Scope()
+        main_program = paddle.static.Program()
+        with paddle.static.scope_guard(new_scope):
+            with paddle.static.program_guard(main_program):
+                arr = np.ones([2, 2], dtype="float32")
+                x = paddle.to_tensor(arr, dtype='int32')
+                out1 = paddle.all(x)
+
+                out = exe.run(main_program, {}, fetch_list=[out1.name])
+                np.testing.assert_array_equal(out[0], np.all(arr))
+
+    def test_with_axis(self):
+        place = core.Place()
+        place.set_place(paddle.CPUPlace())
+        exe = paddle.static.Executor(place)
+
+        new_scope = paddle.static.Scope()
+        main_program = paddle.static.Program()
+        with paddle.static.scope_guard(new_scope):
+            with paddle.static.program_guard(main_program):
+                arr = np.ones([2, 2], dtype="float32")
+                x = paddle.to_tensor(arr, dtype='int32')
+                out1 = paddle.all(x, axis=0)
+
+                out = exe.run(main_program, {}, fetch_list=[out1.name])
+                np.testing.assert_array_equal(out[0], np.all(arr, axis=0))
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -139,6 +139,7 @@ test_prior_box_op
 test_psroi_pool_op
 test_put_along_axis_op
 test_range
+test_reduce_op
 test_reverse_op
 test_roi_align_op
 test_roi_pool_op