Unverified · Commit 5d40f2a2, authored by kangguangli, committed by GitHub

[IR] refine program translator (#54719)

* refine program translator

* fix warning: not override

* fix bug

* merge new modifications

* modify by reviews

* resolve conflicts

* resolve conflicts

* fix

* fix

* fix conflicts

* add unittest for special op transcriber

* set cpu as default backend

* modify by reviews
Parent 051e55c6
@@ -1013,9 +1013,6 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
                 name in forward_outputs_position_map.keys()
             ), AssertMessage(name, forward_outputs_position_map.keys())

-            if is_optional:
-                set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
-            else:
             set_tensor_wrappers = (
                 f"{indent}grad_node->SetTensorWrapper{name}({name});"
             )
......
@@ -35,3 +35,128 @@
  force_backend: null
  inplace: null
  backward: null
- name: share_buffer_
inputs:
- typename: Tensor[]
name: x
optional: false
no_need_buffer: false
data_transform: {}
attrs:
- {typename: 'bool[]', name: share_dims_and_dtype, default_value: '{}'}
outputs:
- {typename: 'Tensor[]', name: out, size: x.size(), optional: false, intermediate: false}
- {typename: 'Tensor[]', name: xout, size: x.size(), optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
inplace: null
backward: null
- name: assert
inputs:
- typename: Tensor
name: cond
optional: false
no_need_buffer: false
data_transform: {}
- typename: Tensor[]
name: data
optional: false
no_need_buffer: false
data_transform: {}
attrs:
- {typename: 'int64_t', name: summarize, default_value: '-1'}
outputs: []
no_need_buffer: null
data_transform: null
inplace: null
backward: null
- name: print
inputs:
- typename: Tensor
name: in
optional: false
no_need_buffer: false
data_transform: {}
attrs:
- {typename: 'int', name: first_n}
- {typename: 'str', name: message}
- {typename: 'int', name: summarize}
- {typename: 'bool', name: print_tensor_name, default_value: 'true'}
- {typename: 'bool', name: print_tensor_type, default_value: 'true'}
- {typename: 'bool', name: print_tensor_shape, default_value: 'true'}
- {typename: 'bool', name: print_tensor_layout, default_value: 'true'}
- {typename: 'bool', name: print_tensor_lod, default_value: 'true'}
- {typename: 'str', name: print_phase, default_value: 'BOTH'}
- {typename: 'bool', name: is_forward, default_value: 'true'}
outputs:
- typename: Tensor
name: out
optional: false
no_need_buffer: false
data_transform: {}
no_need_buffer: null
data_transform: null
inplace: null
backward: null
- name: add_n_
inputs:
- typename: Tensor[]
name: inputs
optional: false
no_need_buffer: false
data_transform: {}
attrs: []
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
invoke: {func: add_n_impl, args: inputs}
backward: add_n_grad
- name: write_to_array
inputs:
- typename: Tensor
name: i
optional: false
no_need_buffer: false
data_transform: {}
- typename: Tensor
name: x
optional: false
no_need_buffer: false
data_transform: {}
attrs: []
outputs:
- {typename: 'Tensor[]', name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
backward: write_to_array_grad
- name: lod_array_length
inputs:
- typename: Tensor[]
name: x
optional: false
no_need_buffer: false
data_transform: {}
attrs: []
outputs:
- {typename: 'Tensor', name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
- name: py_func_
inputs:
- {typename: 'Tensor', name: x, optional: false, no_need_buffer: false, data_transform: {}}
attrs:
- {typename: 'int', name: forward_callable_id, default_value: '0'}
- {typename: 'int', name: backward_callable_id, default_value: '-1'}
- {typename: 'str[]', name: backward_skip_vars, default_value: '{}'}
outputs:
- {typename: 'Tensor', name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
@@ -173,6 +173,10 @@ phi::KernelKey GetKernelKey(
     }
   }

+  if (kernel_backend == phi::Backend::UNDEFINED) {
+    kernel_backend = paddle::experimental::ParseBackend(place);
+  }
+
   phi::KernelKey res(kernel_backend, kernel_layout, kernel_data_type);
   return res;
 }
......
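The hunk above adds the fallback behind "set cpu as default backend" in the commit message: when no backend could be deduced from the op's inputs or attributes, it is parsed from the given place instead of staying `UNDEFINED`. Below is a minimal, self-contained sketch of that fallback pattern; `Backend`, `Place`, and `ParseBackend` here are stand-ins, not Paddle's actual `phi` types.

```cpp
#include <iostream>

// Stand-ins for phi::Backend / phi::Place; the real translator calls
// paddle::experimental::ParseBackend(place) for this step.
enum class Backend { UNDEFINED, CPU, GPU };
enum class AllocationType { CPU, GPU };
struct Place { AllocationType type; };

Backend ParseBackend(const Place& place) {
  // Map the allocation type of the place onto a kernel backend.
  return place.type == AllocationType::GPU ? Backend::GPU : Backend::CPU;
}

Backend GetKernelBackend(Backend deduced, const Place& place) {
  // Fall back to the place-derived backend only when deduction failed.
  if (deduced == Backend::UNDEFINED) {
    deduced = ParseBackend(place);
  }
  return deduced;
}

int main() {
  Place cpu_place{AllocationType::CPU};
  std::cout << (GetKernelBackend(Backend::UNDEFINED, cpu_place) == Backend::CPU)
            << "\n";  // prints 1: an undefined backend resolves to CPU here
}
```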
@@ -42,6 +42,11 @@ class AttributeVisitor {
     return ir::Int32Attribute::get(ctx, i);
   }

+  virtual ir::Attribute operator()(int64_t i) {
+    VLOG(10) << "translating int";
+    return ir::Int64Attribute::get(ctx, i);
+  }
+
   virtual ir::Attribute operator()(float f) {
     VLOG(10) << "translating float";
     return ir::FloatAttribute::get(ctx, f);
@@ -146,6 +151,21 @@ class AttributeVisitor {
   }
 };

+class Int64ArrayAttributeVisitor : public AttributeVisitor {
+ public:
+  using AttributeVisitor::AttributeVisitor;
+
+  ir::Attribute operator()(const std::vector<int>& is) override {
+    VLOG(10) << "translating vector<int64>";
+    std::vector<ir::Attribute> attrs;
+    attrs.reserve(is.size());
+    for (const auto& v : is) {
+      attrs.push_back(ir::Int64Attribute::get(ctx, v));
+    }
+    return ir::ArrayAttribute::get(ctx, attrs);
+  }
+};
+
 class IntArrayAttributeVisitor : public AttributeVisitor {
  public:
   using AttributeVisitor::AttributeVisitor;
@@ -171,6 +191,11 @@ class DataTypeAttributeVisitor : public AttributeVisitor {
     auto phi_dtype = phi::TransToPhiDataType(i);
     return paddle::dialect::DataTypeAttribute::get(ctx, phi_dtype);
   }
+
+  ir::Attribute operator()(const paddle::blank& blank) override {
+    VLOG(10) << "translating paddle::blank to DataType::UNDEFINED";
+    return paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType());
+  }
 };

 class PlaceAttributeVisitor : public AttributeVisitor {
@@ -178,8 +203,8 @@ class PlaceAttributeVisitor : public AttributeVisitor {
   using AttributeVisitor::AttributeVisitor;

   ir::Attribute operator()(const paddle::blank& blank) override {
-    VLOG(10) << "translating paddle::blank";
-    phi::Place data(phi::AllocationType::CPU);
+    VLOG(10) << "translating paddle::blank to Place::UNDEFINED";
+    phi::Place data(phi::AllocationType::UNDEFINED);
     return paddle::dialect::PlaceAttribute::get(ctx, data);
   }
 };
@@ -192,6 +217,8 @@ AttributeTranslator::AttributeTranslator() {
       new DataTypeAttributeVisitor();
   special_visitors["paddle::dialect::PlaceAttribute"] =
       new PlaceAttributeVisitor();
+  special_visitors["ir::ArrayAttribute<ir::Int64Attribute>"] =
+      new Int64ArrayAttributeVisitor();
 }

 ir::Attribute AttributeTranslator::operator()(
......
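The new `Int64ArrayAttributeVisitor` is registered for the target type name `ir::ArrayAttribute<ir::Int64Attribute>`, so a legacy `std::vector<int>` attribute can be widened element by element into an array of 64-bit integer attributes. A rough standalone sketch of that widening conversion follows; the `Int64Attr`/`ArrayAttr` types are toy stand-ins, not Paddle's attribute classes.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-ins: a single 64-bit "attribute" and an array of them.
struct Int64Attr { int64_t value; };
using ArrayAttr = std::vector<Int64Attr>;

// Analogue of Int64ArrayAttributeVisitor::operator()(const std::vector<int>&):
// each 32-bit element is widened to an int64 attribute and collected into an
// array attribute.
ArrayAttr TranslateInt64Array(const std::vector<int>& is) {
  ArrayAttr attrs;
  attrs.reserve(is.size());
  for (int v : is) {
    attrs.push_back(Int64Attr{static_cast<int64_t>(v)});
  }
  return attrs;
}

int main() {
  ArrayAttr a = TranslateInt64Array({1, 2, 3});
  std::cout << a.size() << " " << a.back().value << "\n";  // prints "3 3"
}
```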
@@ -36,7 +36,7 @@ class OpTranslator {
   using BlockDesc = paddle::framework::BlockDesc;
   using VarDesc = paddle::framework::VarDesc;
   using OpTranslateFn = std::function<ir::Operation*(
-      ir::IrContext*, TranslationContext*, ir::Program*, const OpDesc&)>;
+      ir::IrContext*, TranslationContext*, const OpDesc&, ir::Program*)>;

  private:
   OpTranslator();  // Disallow instantiation outside of the class.
......
@@ -111,25 +111,46 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) {
     parameter_name_mappings_[var->Name()] = var;
   }

+  std::unordered_set<std::string> inner_defining_variables;
+
   for (auto op_desc : block.AllOps()) {
     for (const auto& n : op_desc->Inputs()) {
       const auto& input_var_names = n.second;
       for (const auto& var_name : input_var_names) {
-        bool need_get_parameter_op = (parameter_name_mappings_.find(var_name) !=
-                                      parameter_name_mappings_.end());
-        need_get_parameter_op &= (parameter_visited_.count(var_name) == 0);
+        if (no_cast_var_names.count(var_name) != 0) continue;
+        VarDesc* var_desc = nullptr;
+
+        bool is_parameter = (parameter_name_mappings_.find(var_name) !=
+                             parameter_name_mappings_.end());
+        is_parameter &= (parameter_visited_.count(var_name) == 0);
+        if (is_parameter) {
+          var_desc = parameter_name_mappings_[var_name];
+        }
+        bool is_unseen_variable =
+            (inner_defining_variables.count(var_name) == 0);
+        if (is_unseen_variable) {
+          var_desc = block.FindVarRecursive(var_name);
+        }
+
+        bool need_get_parameter_op = is_parameter || is_unseen_variable;
         if (need_get_parameter_op) {
-          ir::Operation* op =
-              InsertGetParamaterOp(ctx_, parameter_name_mappings_[var_name]);
+          ir::Operation* op = InsertGetParamaterOp(ctx_, var_desc);
           program_->block()->push_back(op);
           param_map_[var_name] = VariableDefiningInfo(op->result(0));
           VLOG(10) << "[op translated][get parameter]" << op;

           program_->SetParameter(var_name, nullptr);
           parameter_visited_.insert(var_name);
+          inner_defining_variables.insert(var_name);
         }
       }
     }
+    for (const auto& n : op_desc->Outputs()) {
+      const auto& output_var_names = n.second;
+      for (const auto& var_name : output_var_names) {
+        inner_defining_variables.insert(var_name);
+      }
+    }
   }
 }

@@ -137,7 +158,7 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) {
   auto& op_translator = OpTranslator::instance();
   for (auto op : block.AllOps()) {
     OpTranslateFn& fn = op_translator[op->Type()];
-    ir::Operation* operation = fn(ctx_, &param_map_, program_, *op);
+    ir::Operation* operation = fn(ctx_, &param_map_, *op, program_);
     VLOG(10) << "[op translated][special]" << operation;
   }
 }
......
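The reworked `GetParameterForSingleBlock` now tracks which variables have already been defined inside the block (op outputs plus freshly inserted get-parameter results) and only emits a get-parameter op for an input it has not seen yet. Here is a condensed sketch of that bookkeeping with plain string sets standing in for the real `VarDesc`/`ir::Operation` machinery; the names and simplified condition are illustrative, not Paddle's exact code.

```cpp
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Op {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Returns the variable names for which a get-parameter op would be inserted.
// Mirrors the structure of the hunk above: inputs not yet defined inside the
// block trigger an insertion; every op's outputs are then recorded as
// inner-defined so later uses do not trigger another insertion.
std::vector<std::string> CollectGetParameterVars(const std::vector<Op>& ops) {
  std::unordered_set<std::string> inner_defining_variables;
  std::unordered_set<std::string> parameter_visited;
  std::vector<std::string> inserted;

  for (const auto& op : ops) {
    for (const auto& var_name : op.inputs) {
      bool is_unseen = inner_defining_variables.count(var_name) == 0 &&
                       parameter_visited.count(var_name) == 0;
      if (is_unseen) {
        inserted.push_back(var_name);        // would insert a GetParameterOp
        parameter_visited.insert(var_name);
        inner_defining_variables.insert(var_name);
      }
    }
    for (const auto& var_name : op.outputs) {
      inner_defining_variables.insert(var_name);  // defined by this op
    }
  }
  return inserted;
}

int main() {
  std::vector<Op> ops = {{{"w"}, {"tmp"}}, {{"tmp", "b"}, {"out"}}};
  for (const auto& name : CollectGetParameterVars(ops)) {
    std::cout << name << "\n";  // "w" and "b" need get-parameter ops, "tmp" not
  }
}
```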
@@ -19,6 +19,7 @@
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/ir/dialect/pd_dialect.h"
 #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
+#include "paddle/ir/core/builtin_dialect.h"
 #include "paddle/ir/core/program.h"

 namespace paddle {
@@ -28,7 +29,9 @@ using Program = ::ir::Program;

 std::unique_ptr<Program> TranslateLegacyProgramToProgram(
     const LegacyProgramDesc& legacy_program) {
-  auto program = std::make_unique<Program>(ir::IrContext::Instance());
+  ir::IrContext* ctx = ir::IrContext::Instance();
+  ctx->GetOrRegisterDialect<dialect::PaddleDialect>();
+  auto program = std::make_unique<Program>(ctx);
   translator::ProgramTranslator program_translator(&legacy_program,
                                                    program.get());
......
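Registering the Paddle dialect on the shared `IrContext` before constructing the `Program` makes sure the translated ops can be created, and a get-or-register call of this kind is typically idempotent, so repeated translations reuse the same registration. A toy get-or-register sketch follows; it is a standalone analogy, not the `ir` library's implementation.

```cpp
#include <iostream>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>

struct Dialect { virtual ~Dialect() = default; };
struct PaddleDialect : Dialect {};

class IrContext {
 public:
  // Creates the dialect on first use and returns the cached instance afterwards.
  template <typename T>
  T* GetOrRegisterDialect() {
    auto& slot = dialects_[std::type_index(typeid(T))];
    if (!slot) slot = std::make_unique<T>();
    return static_cast<T*>(slot.get());
  }

 private:
  std::unordered_map<std::type_index, std::unique_ptr<Dialect>> dialects_;
};

int main() {
  IrContext ctx;
  auto* first = ctx.GetOrRegisterDialect<PaddleDialect>();
  auto* second = ctx.GetOrRegisterDialect<PaddleDialect>();
  std::cout << (first == second) << "\n";  // 1: registration happens only once
}
```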
@@ -93,6 +93,20 @@ TypeTranslator::TypeTranslator() {
            size_t offset = 0;
            return DenseTensorType::get(ctx, dtype, dim, layout, lod, offset);
          }},
+        {VarType::LOD_TENSOR_ARRAY,
+         [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
+           VLOG(10) << "[vartype translating]"
+                    << "[" << var_desc.Name() << "] from LOD_TENSOR_ARRAY";
+
+           return ir::VectorType::get(ctx, std::vector<ir::Type>{});
+         }},
+        {VarType::SELECTED_ROWS,
+         [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
+           VLOG(10) << "[vartype translating]"
+                    << "[" << var_desc.Name() << "] from SELECTED_ROWS";
+
+           return this->operator[](VarType::LOD_TENSOR)(ctx, var_desc);
+         }},
     };
 }
......
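The two new entries map LOD_TENSOR_ARRAY to an initially empty `ir::VectorType` and reuse the LOD_TENSOR handler for SELECTED_ROWS. The dispatch itself is just a map from var type to a conversion callback; a minimal standalone analogue is sketched below, where the enum and the string results are placeholders rather than Paddle types.

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

enum class VarType { LOD_TENSOR, LOD_TENSOR_ARRAY, SELECTED_ROWS };

int main() {
  using Handler = std::function<std::string(const std::string&)>;
  std::unordered_map<VarType, Handler> handlers;

  handlers[VarType::LOD_TENSOR] = [](const std::string& name) {
    return "DenseTensorType(" + name + ")";
  };
  handlers[VarType::LOD_TENSOR_ARRAY] = [](const std::string& name) {
    // Translated to an empty vector type, to be populated by later passes.
    return "VectorType{} for " + name;
  };
  handlers[VarType::SELECTED_ROWS] = [&handlers](const std::string& name) {
    // Reuses the LOD_TENSOR handler, as in the hunk above.
    return handlers[VarType::LOD_TENSOR](name);
  };

  std::cout << handlers[VarType::SELECTED_ROWS]("param0") << "\n";
}
```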
@@ -84,6 +84,7 @@
   kernel :
     func : assign
   backward : assign_grad
+  inplace : (x -> out)

 - op : assign_out_
   args : (Tensor x, Tensor output)
@@ -120,6 +121,7 @@
     data_type : x
   view : (mean -> mean_out), (variance -> variance_out)
   backward : batch_norm_grad
+  optional : reserve_space

 - op : cast
   args : (Tensor x, DataType dtype)
......
@@ -769,7 +769,7 @@
     attrs : ['int[] slots = {}']

 - op : divide (elementwise_div)
-  backward : divide_grad (elementwise_div)
+  backward : divide_grad (elementwise_div_grad)
   inputs :
     {x: X, y : Y}
   outputs :
@@ -1776,6 +1776,8 @@
 - op : mish
   backward : mish_grad
+  inputs:
+    lambda: threshold
   extra :
     attrs : [bool use_mkldnn = false]
@@ -2839,6 +2841,8 @@
     yolo_loss : GetYoloLossExpectedKernelType
     yolo_loss_grad : GetYoloLossExpectedKernelType

+- op: full_batch_size_like (fill_constant_batch_size_like)
+
 - op: lu
   backward: lu_grad
   inputs:
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.fluid import core

paddle.enable_static()


class TestCastOpTranscriber(unittest.TestCase):
    def test_op(self):
        place = core.Place()
        place.set_place(paddle.CPUPlace())
        new_scope = paddle.static.Scope()
        main_program = paddle.static.Program()
        with paddle.static.scope_guard(new_scope):
            with paddle.static.program_guard(main_program):
                x = paddle.to_tensor([2, 3, 4], 'float64')
                y = paddle.cast(x, 'uint8')

                default_job = core.Job("default")
                type_to_program = {"default": main_program.desc}
                plan = core.Plan([default_job], type_to_program)
                new_exe = core.StandaloneExecutor(place, plan, new_scope)


class TestEmbeddingOpTranscriber(unittest.TestCase):
    def test_op(self):
        place = core.Place()
        place.set_place(paddle.CPUPlace())
        new_scope = paddle.static.Scope()
        main_program = paddle.static.Program()
        with paddle.static.scope_guard(new_scope):
            with paddle.static.program_guard(main_program):
                x = paddle.static.data(name="x", shape=[2, 4], dtype=np.int64)
                embedding = paddle.nn.Embedding(
                    10, 3, weight_attr=paddle.nn.initializer.Constant(value=1.0)
                )
                output = embedding(x)

                default_job = core.Job("default")
                type_to_program = {"default": main_program.desc}
                plan = core.Plan([default_job], type_to_program)
                new_exe = core.StandaloneExecutor(place, plan, new_scope)


class TestIncrementOpTranscriber(unittest.TestCase):
    def test_op(self):
        place = core.Place()
        place.set_place(paddle.CPUPlace())
        new_scope = paddle.static.Scope()
        main_program = paddle.static.Program()
        with paddle.static.scope_guard(new_scope):
            with paddle.static.program_guard(main_program):
                data = paddle.zeros(shape=[1], dtype='float32')
                counter = paddle.increment(data)

                default_job = core.Job("default")
                type_to_program = {"default": main_program.desc}
                plan = core.Plan([default_job], type_to_program)
                new_exe = core.StandaloneExecutor(place, plan, new_scope)


class TestAssignValueOpTranscriber(unittest.TestCase):
    def test_op(self):
        place = core.Place()
        place.set_place(paddle.CPUPlace())
        new_scope = paddle.static.Scope()
        main_program = paddle.static.Program()
        with paddle.static.scope_guard(new_scope):
            with paddle.static.program_guard(main_program):
                x = paddle.to_tensor(
                    [[0.1, 0.2], [0.3, 0.4]],
                    place=paddle.CPUPlace(),
                    stop_gradient=False,
                )

                default_job = core.Job("default")
                type_to_program = {"default": main_program.desc}
                plan = core.Plan([default_job], type_to_program)
                new_exe = core.StandaloneExecutor(place, plan, new_scope)


class TestRnnOpTranscriber(unittest.TestCase):
    def test_op(self):
        place = core.Place()
        place.set_place(paddle.CPUPlace())
        new_scope = paddle.static.Scope()
        main_program = paddle.static.Program()
        with paddle.static.scope_guard(new_scope):
            with paddle.static.program_guard(main_program):
                x = paddle.randn((4, 16))
                prev_h = paddle.randn((4, 32))
                cell = paddle.nn.SimpleRNNCell(16, 32)
                y, h = cell(x, prev_h)

                default_job = core.Job("default")
                type_to_program = {"default": main_program.desc}
                plan = core.Plan([default_job], type_to_program)
                new_exe = core.StandaloneExecutor(place, plan, new_scope)


if __name__ == "__main__":
    unittest.main()