Unverified commit 8cf4b1c2 authored by YuanRisheng, committed by GitHub

[NewIR]Polish IR code (#56087)

* perfect code

* delete __all__
Parent 9e6f4433
...
@@ -65,7 +65,7 @@ def GenBuildInputArgsStr(
                ]
                if (
                    op_attribute_build_arg_type_list[attr_idx]
-                    != "std::string"
+                    != "const std::string&"
                ):
                    if (
                        default_value[0] == "'"
...
@@ -106,7 +106,7 @@ def GenBuildInputArgsStr(
                    op_non_mutable_attribute_build_arg_type_list[
                        attr_idx
                    ]
-                    != "std::string"
+                    != "const std::string&"
                ):
                    if (
                        default_value[0] == "'"
......
@@ -258,12 +258,12 @@ class OpInfoParser:
                'ir::ArrayAttribute<ir::BoolAttribute>',
                'const std::vecot<bool>&',
            ],
-            'str': ['ir::StrAttribute', 'std::string'],
+            'str': ['ir::StrAttribute', 'const std::string&'],
            'str[]': [
                'ir::ArrayAttribute<ir::StrAttribute>',
                'const std::vector<std::string>&',
            ],
-            'Place': ['paddle::dialect::PlaceAttribute', 'Place'],
+            'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'],
            'DataLayout': [
                'paddle::dialect::DataLayoutAttribute',
                'DataLayout',
...
@@ -577,7 +577,7 @@ class OpInfoParser:
            temp_type = attribute_info['data_type']
            if 'IntArray' in temp_type:
                if 'data_type' in attribute_info:
-                    temp_type = attribute_info['data_type']
+                    temp_type = "const " + attribute_info['data_type'] + "&"
            type_list.append(self.get_phi_dtype_name(temp_type))
        return type_list
......
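Editor's note (not part of the diff): the table changed above pairs each op-YAML attribute kind with its IR attribute class and the C++ argument type that op_gen.py emits into generated Build() signatures; this commit switches heavyweight types to const references. A minimal illustrative sketch of that lookup, with simplified names that are not the actual generator code:

```python
# Illustrative only: a trimmed-down version of the attribute-kind -> C++
# arg-type mapping that the generator consults when rendering Build(...) args.
attr_types_map = {
    'str': ['ir::StrAttribute', 'const std::string&'],            # was 'std::string'
    'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'],  # was 'Place'
    'int64_t[]': [
        'ir::ArrayAttribute<ir::Int64Attribute>',
        'const std::vector<int64_t>&',
    ],
}


def build_arg_decl(attr_name, attr_kind):
    """Render one argument of a generated Build() signature."""
    cpp_type = attr_types_map[attr_kind][1]
    return f"{cpp_type} {attr_name}"


print(build_arg_decl("axis", "int64_t[]"))  # const std::vector<int64_t>& axis
```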
@@ -29,7 +29,9 @@ ir::OpResult add_n(std::vector<ir::OpResult> x) {
  return add_n_op.out();
}

-ir::OpResult mean(ir::OpResult x, std::vector<int64_t> axis, bool keepdim) {
+ir::OpResult mean(ir::OpResult x,
+                  const std::vector<int64_t>& axis,
+                  bool keepdim) {
  paddle::dialect::MeanOp mean_op =
      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::MeanOp>(
          x, axis, keepdim);
...
@@ -37,27 +39,27 @@ ir::OpResult mean(ir::OpResult x, std::vector<int64_t> axis, bool keepdim) {
}

ir::OpResult sum(ir::OpResult x,
-                 std::vector<int64_t> axis,
+                 const std::vector<int64_t>& axis,
                 phi::DataType dtype,
                 bool keepdim) {
-  paddle::dialect::SumOp sum_op =
+  auto sum_op =
      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SumOp>(
          x, axis, dtype, keepdim);
  return sum_op.out();
}

ir::OpResult divide(ir::OpResult x, ir::OpResult y) {
-  paddle::dialect::DivideOp divide_op =
+  auto divide_op =
      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::DivideOp>(x,
                                                                            y);
  return divide_op.out();
}

-ir::OpResult full(std::vector<int64_t> shape,
+ir::OpResult full(const std::vector<int64_t>& shape,
                  float value,
                  phi::DataType dtype,
-                  phi::Place place) {
-  paddle::dialect::FullOp full_op =
+                  const phi::Place& place) {
+  auto full_op =
      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::FullOp>(
          shape, value, dtype, place);
  return full_op.out();
......
@@ -25,20 +25,20 @@ namespace dialect {
ir::OpResult add_n(std::vector<ir::OpResult> x);

ir::OpResult mean(ir::OpResult x,
-                  std::vector<int64_t> axis = {},
+                  const std::vector<int64_t>& axis = {},
                  bool keepdim = false);

ir::OpResult sum(ir::OpResult x,
-                 std::vector<int64_t> axis = {},
+                 const std::vector<int64_t>& axis = {},
                 phi::DataType dtype = phi::DataType::UNDEFINED,
                 bool keepdim = false);

ir::OpResult divide(ir::OpResult x, ir::OpResult y);

-ir::OpResult full(std::vector<int64_t> shape,
+ir::OpResult full(const std::vector<int64_t>& shape,
                  float value,
                  phi::DataType dtype = phi::DataType::FLOAT32,
-                  phi::Place place = phi::CPUPlace());
+                  const phi::Place& place = phi::CPUPlace());

ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir/dialect/utils.h"
namespace paddle {
namespace dialect {
enum class AttrType {
UNDEFINED = 0,
BOOL,
INT32,
INT64,
FLOAT,
DOUBLE,
ARRAY,
INT_ARRAY,
SCALAR,
DATA_TYPE,
DATA_LAYOUT,
PLACE,
STRING,
NUM_ATTR_TYPES,
};
static inline AttrType GetAttributeType(const ir::Attribute& attr) {
if (attr.isa<ir::BoolAttribute>()) {
return AttrType::BOOL;
} else if (attr.isa<ir::FloatAttribute>()) {
return AttrType::FLOAT;
} else if (attr.isa<ir::DoubleAttribute>()) {
return AttrType::DOUBLE;
} else if (attr.isa<ir::Int32Attribute>()) {
return AttrType::INT32;
} else if (attr.isa<ir::Int64Attribute>()) {
return AttrType::INT64;
} else if (attr.isa<ir::ArrayAttribute>()) {
return AttrType::ARRAY;
} else if (attr.isa<ir::StrAttribute>()) {
return AttrType::STRING;
} else if (attr.isa<paddle::dialect::IntArrayAttribute>()) {
return AttrType::INT_ARRAY;
} else if (attr.isa<paddle::dialect::DataTypeAttribute>()) {
return AttrType::DATA_TYPE;
} else if (attr.isa<paddle::dialect::PlaceAttribute>()) {
return AttrType::PLACE;
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ir Attribute type when casting it into "
"AttrType."));
}
}
static std::unordered_map<AttrType,
std::function<VariantType(const ir::Attribute& attr)>>
kAttrCastMap = {
{AttrType::BOOL,
[](const ir::Attribute& attr) {
return VariantType{attr.dyn_cast<ir::BoolAttribute>().data()};
}},
{AttrType::FLOAT,
[](const ir::Attribute& attr) {
return VariantType{attr.dyn_cast<ir::FloatAttribute>().data()};
}},
{AttrType::DOUBLE,
[](const ir::Attribute& attr) {
return VariantType{attr.dyn_cast<ir::DoubleAttribute>().data()};
}},
{AttrType::INT32,
[](const ir::Attribute& attr) {
return VariantType{attr.dyn_cast<ir::Int32Attribute>().data()};
}},
{AttrType::INT64,
[](const ir::Attribute& attr) {
return VariantType{attr.dyn_cast<ir::Int64Attribute>().data()};
}},
{AttrType::INT_ARRAY,
[](const ir::Attribute& attr) {
return VariantType{
attr.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData()};
}},
{AttrType::STRING,
[](const ir::Attribute& attr) {
return VariantType{attr.dyn_cast<ir::StrAttribute>().AsString()};
}},
{AttrType::DATA_TYPE,
[](const ir::Attribute& attr) {
return VariantType{
attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data()};
}},
{AttrType::PLACE,
[](const ir::Attribute& attr) {
return VariantType{
attr.dyn_cast<paddle::dialect::PlaceAttribute>().data()};
}},
{AttrType::ARRAY,
[](const ir::Attribute& attr) {
auto attr_vec = attr.dyn_cast<ir::ArrayAttribute>().AsVector();
if (attr_vec.size() == 0) {
return VariantType{std::vector<int>()};
}
AttrType element_type = GetAttributeType(attr_vec[0]);
if (element_type == AttrType::BOOL) {
std::vector<bool> vec_bools;
for (auto vec_element : attr_vec) {
vec_bools.push_back(
vec_element.dyn_cast<ir::BoolAttribute>().data());
}
return VariantType{vec_bools};
} else if (element_type == AttrType::INT32) {
std::vector<int> vec_int32;
for (auto vec_element : attr_vec) {
vec_int32.push_back(
vec_element.dyn_cast<ir::Int32Attribute>().data());
}
return VariantType{vec_int32};
} else if (element_type == AttrType::INT64) {
std::vector<int64_t> vec_int64;
for (auto vec_element : attr_vec) {
vec_int64.push_back(
vec_element.dyn_cast<ir::Int64Attribute>().data());
}
return VariantType{vec_int64};
} else if (element_type == AttrType::FLOAT) {
std::vector<float> vec_float;
for (auto vec_element : attr_vec) {
vec_float.push_back(
vec_element.dyn_cast<ir::FloatAttribute>().data());
}
return VariantType{vec_float};
} else if (element_type == AttrType::DOUBLE) {
std::vector<double> vec_double;
for (auto vec_element : attr_vec) {
vec_double.push_back(
vec_element.dyn_cast<ir::DoubleAttribute>().data());
}
return VariantType{vec_double};
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ir Attribute type when casting it into "
"vector."));
}
}},
};
VariantType GetAttributeData(const ir::Attribute& attr) {
AttrType attr_type = GetAttributeType(attr);
return kAttrCastMap[attr_type](attr);
}
} // namespace dialect
} // namespace paddle
...
@@ -141,155 +141,7 @@ static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar,
  }
}

enum class AttrType {
UNDEFINED = 0,
BOOL,
INT32,
INT64,
FLOAT,
DOUBLE,
ARRAY,
INT_ARRAY,
SCALAR,
DATA_TYPE,
DATA_LAYOUT,
PLACE,
STRING,
NUM_ATTR_TYPES,
};
static inline AttrType GetAttributeType(const ir::Attribute& attr) {
if (attr.isa<ir::BoolAttribute>()) {
return AttrType::BOOL;
} else if (attr.isa<ir::FloatAttribute>()) {
return AttrType::FLOAT;
} else if (attr.isa<ir::DoubleAttribute>()) {
return AttrType::DOUBLE;
} else if (attr.isa<ir::Int32Attribute>()) {
return AttrType::INT32;
} else if (attr.isa<ir::Int64Attribute>()) {
return AttrType::INT64;
} else if (attr.isa<ir::ArrayAttribute>()) {
return AttrType::ARRAY;
} else if (attr.isa<ir::StrAttribute>()) {
return AttrType::STRING;
} else if (attr.isa<paddle::dialect::IntArrayAttribute>()) {
return AttrType::INT_ARRAY;
} else if (attr.isa<paddle::dialect::DataTypeAttribute>()) {
return AttrType::DATA_TYPE;
} else if (attr.isa<paddle::dialect::PlaceAttribute>()) {
return AttrType::PLACE;
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ir Attribute type when casting it into "
"AttrType."));
}
}
static std::unordered_map<AttrType,
std::function<VariantType(const ir::Attribute& attr)>>
attr_cast_map = {
{AttrType::BOOL,
[](const ir::Attribute& attr) {
return VariantType{attr.dyn_cast<ir::BoolAttribute>().data()};
}},
{AttrType::FLOAT,
[](const ir::Attribute& attr) {
return VariantType{attr.dyn_cast<ir::FloatAttribute>().data()};
}},
{AttrType::DOUBLE,
[](const ir::Attribute& attr) {
return VariantType{attr.dyn_cast<ir::DoubleAttribute>().data()};
}},
{AttrType::INT32,
[](const ir::Attribute& attr) {
return VariantType{attr.dyn_cast<ir::Int32Attribute>().data()};
}},
{AttrType::INT64,
[](const ir::Attribute& attr) {
return VariantType{attr.dyn_cast<ir::Int64Attribute>().data()};
}},
{AttrType::INT_ARRAY,
[](const ir::Attribute& attr) {
return VariantType{
attr.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData()};
}},
{AttrType::STRING,
[](const ir::Attribute& attr) {
return VariantType{attr.dyn_cast<ir::StrAttribute>().AsString()};
}},
{AttrType::DATA_TYPE,
[](const ir::Attribute& attr) {
return VariantType{
attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data()};
}},
{AttrType::PLACE,
[](const ir::Attribute& attr) {
return VariantType{
attr.dyn_cast<paddle::dialect::PlaceAttribute>().data()};
}},
{AttrType::ARRAY,
[](const ir::Attribute& attr) {
auto attr_vec = attr.dyn_cast<ir::ArrayAttribute>().AsVector();
if (attr_vec.size() == 0) {
return VariantType{std::vector<int>()};
}
AttrType element_type = GetAttributeType(attr_vec[0]);
if (element_type == AttrType::BOOL) {
std::vector<bool> vec_bools;
for (auto vec_element : attr_vec) {
vec_bools.push_back(
vec_element.dyn_cast<ir::BoolAttribute>().data());
}
return VariantType{vec_bools};
} else if (element_type == AttrType::INT32) {
std::vector<int> vec_int32;
for (auto vec_element : attr_vec) {
vec_int32.push_back(
vec_element.dyn_cast<ir::Int32Attribute>().data());
}
return VariantType{vec_int32};
} else if (element_type == AttrType::INT64) {
std::vector<int64_t> vec_int64;
for (auto vec_element : attr_vec) {
vec_int64.push_back(
vec_element.dyn_cast<ir::Int64Attribute>().data());
}
return VariantType{vec_int64};
} else if (element_type == AttrType::FLOAT) {
std::vector<float> vec_float;
for (auto vec_element : attr_vec) {
vec_float.push_back(
vec_element.dyn_cast<ir::FloatAttribute>().data());
}
return VariantType{vec_float};
} else if (element_type == AttrType::DOUBLE) {
std::vector<double> vec_double;
for (auto vec_element : attr_vec) {
vec_double.push_back(
vec_element.dyn_cast<ir::DoubleAttribute>().data());
}
return VariantType{vec_double};
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ir Attribute type when casting it into "
"vector."));
}
}},
};
static inline VariantType GetAttributeData(const ir::Attribute& attr) {
AttrType attr_type = GetAttributeType(attr);
return attr_cast_map[attr_type](attr);
}
+VariantType GetAttributeData(const ir::Attribute& attr);

}  // namespace dialect
}  // namespace paddle
...
@@ -57,7 +57,46 @@ PyTypeObject *g_ir_opresult_pytype = nullptr;
void BindOpsAPI(pybind11::module *module);

void BindProgram(py::module *m) {
-  py::class_<Program> program(*m, "Program");
+  py::class_<Program> program(*m, "Program", R"DOC(
Create Python Program. Program is an abstraction of model structure, divided into
computational graphs and weights. The Program has a main block that stores the computational
graphs.
A set of Program usually contains startup program and main program.
A startup program is set to contain some initial work, eg. initialize the ``Parameter``, and the main
program will contain the network structure and vars for train.
A set of Program can be used for test or train, in train program ,
Paddle will contain all content to build a train network, in test
program Paddle will prune some content which is irrelevant to test, eg.
backward ops and vars.
**Notes**:
**we have** :ref:`api_paddle_static_default_startup_program` **and** :ref:`api_paddle_static_default_main_program`
**by default, a pair of them will shared the parameters. The** :ref:`api_paddle_static_default_startup_program` **only run once to initialize parameters,**
:ref:`api_paddle_static_default_main_program` **run in every mini batch and adjust the weights.**
Returns:
Program: An empty Program.
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
main_program = static.Program()
startup_program = static.Program()
with static.program_guard(main_program=main_program, startup_program=startup_program):
x = static.data(name="x", shape=[-1, 784], dtype='float32')
y = static.data(name="y", shape=[-1, 1], dtype='int32')
z = static.nn.fc(name="fc", x=x, size=10, activation="relu")
print("main program is: {}".format(main_program))
print("start up program is: {}".format(startup_program))
)DOC");
  program
      .def(
          "__init__",
...
@@ -78,7 +117,13 @@ void BindProgram(py::module *m) {
}

void BindBlock(py::module *m) {
-  py::class_<Block> block(*m, "Block");
+  py::class_<Block> block(*m, "Block", R"DOC(
In IR, a Block has a list of Operation and can represent a sub computational graph.
Notes:
The constructor of Block should not be invoked directly. You can
use `Program.block()` to get a block.
)DOC");
block.def("front", &Block::front, return_value_policy::reference) block.def("front", &Block::front, return_value_policy::reference)
.def("get_parent_program", .def("get_parent_program",
[](Block &self) { return self.GetParentOp()->GetParentProgram(); }) [](Block &self) { return self.GetParentOp()->GetParentProgram(); })
...@@ -91,14 +136,35 @@ void BindBlock(py::module *m) { ...@@ -91,14 +136,35 @@ void BindBlock(py::module *m) {
} }
return op_list; return op_list;
}) })
.def("remove_op", [](Block &self, Operation *op) { .def(
auto op_iter = std::find(self.begin(), self.end(), op); "remove_op",
self.erase(op_iter); [](Block &self, Operation *op) {
}); auto op_iter = std::find(self.begin(), self.end(), op);
self.erase(op_iter);
},
R"DOC(
Remove the specific position operator.
Args:
index(int): the position that the operator to insert.
Returns:
None
)DOC");
}

void BindOperation(py::module *m) {
-  py::class_<Operation> op(*m, "Operation");
+  py::class_<Operation> op(*m, "Operation", R"DOC(
In IR, all the operation are represented by Operation, and Operation
is regarded as a build in an instruction of a Block. Users can call
python api to describe their neural network.
Notes:
The constructor of operator should not be invoked directly. Use
python api, for example: paddle.mean for building mean operation.
)DOC");
op.def("name", &Operation::name) op.def("name", &Operation::name)
.def("get_parent_block", .def("get_parent_block",
py::overload_cast<>(&Operation::GetParent), py::overload_cast<>(&Operation::GetParent),
...@@ -170,7 +236,15 @@ void BindOperation(py::module *m) { ...@@ -170,7 +236,15 @@ void BindOperation(py::module *m) {
} }
void BindValue(py::module *m) { void BindValue(py::module *m) {
py::class_<Value> value(*m, "Value"); py::class_<Value> value(*m, "Value", R"DOC(
Value class represents the SSA value in the IR system. It is a directed edge
and a base class.
Notes:
The constructor of Value should not be invoked directly. Value can be automatically constructed
when build network.
)DOC");
  value
      .def("get_defining_op",
           &Value::GetDefiningOp,
...
@@ -185,7 +259,16 @@ void BindValue(py::module *m) {
}

void BindOpOperand(py::module *m) {
-  py::class_<OpOperand> op_operand(*m, "OpOperand");
+  py::class_<OpOperand> op_operand(*m,
"OpOperand",
R"DOC(
OpOperand class represents the op_operand (input) of operation.
Notes:
The constructor of OpOperand should not be invoked directly. OpOperand can be automatically constructed
when build network.
)DOC");
  op_operand
      .def("source",
           [](OpOperand &self) { return self.source().dyn_cast<OpResult>(); })
...
@@ -228,7 +311,13 @@ void SetStopGradient(const OpResult &self, bool stop_gradient) {
}

void BindOpResult(py::module *m) {
-  py::class_<OpResult> op_result(*m, "OpResult");
+  py::class_<OpResult> op_result(*m, "OpResult", R"DOC(
OpResult class represents the value(output) defined by a result of operation.
Notes:
The constructor of OpResult should not be invoked directly. OpResult can be automatically constructed
when build network.
)DOC");
  g_ir_opresult_pytype = reinterpret_cast<PyTypeObject *>(op_result.ptr());
  op_result.def("__eq__", &OpResult::operator==)
      .def("__eq__",
...
@@ -301,7 +390,42 @@ void BindUtils(pybind11::module *m) {
         []() { APIBuilder::Instance().ResetInsertionPointToStart(); });
  m->def("reset_insertion_point_to_end",
         []() { APIBuilder::Instance().ResetInsertionPointToEnd(); });
-  m->def("translate_to_new_ir", &paddle::TranslateLegacyProgramToProgram);
+  m->def("translate_to_new_ir", &paddle::TranslateLegacyProgramToProgram, R"DOC(
Convert Fluid Program to New IR Program.
Args:
legacy_program (ProgramDesc): The Fluid Program that will be converted.
Returns:
Program: The New IR Program
Raises:
PreconditionNotMet: If legacy_program has multi block will raise error.
Examples:
.. code-block:: python
import paddle
from paddle import ir
paddle.enable_static()
x = paddle.randn([4, 4])
main_program, start_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, start_program):
x_s = paddle.static.data('x', [4, 4], x.dtype)
x_s.stop_gradient = False
y_s = paddle.matmul(x_s, x_s)
z_s = paddle.add(y_s, y_s)
k_s = paddle.tanh(z_s)
newir_program = ir.translate_to_new_ir(main_program.desc)
print(newir_program)
)DOC");
}

void BindNewIR(pybind11::module *module) {
......
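Editor's note (not part of the diff): the new docstrings above describe a small end-to-end Python workflow. A hedged usage sketch, assuming a Paddle build with the new IR bindings and the `paddle.ir` package shown later in this commit; it largely mirrors the example embedded in the `translate_to_new_ir` docstring:

```python
import paddle
from paddle import ir

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data('x', [4, 4], 'float32')
    y = paddle.tanh(x)

# Convert the legacy ProgramDesc into a New IR Program
# (see the translate_to_new_ir docstring above).
newir_program = ir.translate_to_new_ir(main_program.desc)
print(newir_program)

# Per the Block docstring, a Block is not constructed directly; it is obtained
# from its Program, e.g. via Program.block().
block = newir_program.block()
# block.remove_op(op) then erases a previously obtained Operation from it.
```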
@@ -1016,32 +1016,8 @@ def convert_np_dtype_to_dtype_(np_dtype):
    dtype = np.dtype(np_dtype)
    if ir.core._use_new_ir_api():
-        if dtype == np.float32:
-            return core.DataType.FLOAT32
-        elif dtype == np.float64:
-            return core.DataType.FLOAT64
-        elif dtype == np.float16:
-            return core.DataType.FLOAT16
-        elif dtype == np.int32:
-            return core.DataType.INT32
-        elif dtype == np.int16:
-            return core.DataType.INT16
-        elif dtype == np.int64:
-            return core.DataType.INT64
-        elif dtype == np.bool_:
-            return core.DataType.BOOL
-        elif dtype == np.uint16:
-            # since there is still no support for bfloat16 in NumPy,
-            # uint16 is used for casting bfloat16
-            return core.DataType.UINT16
-        elif dtype == np.uint8:
-            return core.DataType.UINT8
-        elif dtype == np.int8:
-            return core.DataType.INT8
-        elif dtype == np.complex64:
-            return core.DataType.COMPLEX64
-        elif dtype == np.complex128:
-            return core.DataType.COMPLEX128
+        if dtype in ir.core.np_type_to_paddle_type.keys():
+            return ir.core.np_type_to_paddle_type[dtype]
        else:
            raise ValueError("Not supported numpy dtype %s" % dtype)
    else:
......
@@ -31,13 +31,4 @@ from paddle.fluid.libpaddle.ir import (
from . import core

-__all__ = [  # noqa
-    'Program',
-    'Block',
-    'Operation',
-    'Value',
-    'OpOperand',
-    'OpResult',
-    'Type',
-    'translate_to_new_ir',
-]
+__all__ = []
...
@@ -13,11 +13,29 @@
# limitations under the License.

import numpy as np

import paddle
from paddle.fluid.libpaddle import DataType
from paddle.fluid.libpaddle.ir import Program, set_global_program

from ..fluid.wrapped_decorator import signature_safe_contextmanager
np_type_to_paddle_type = {
np.dtype("float32"): DataType.FLOAT32,
np.dtype("float64"): DataType.FLOAT64,
np.dtype("float16"): DataType.FLOAT16,
np.dtype("int32"): DataType.INT32,
np.dtype("int16"): DataType.INT16,
np.dtype("int64"): DataType.INT64,
np.dtype("bool_"): DataType.BOOL,
np.dtype("uint16"): DataType.UINT16,
np.dtype("uint8"): DataType.UINT8,
np.dtype("int8"): DataType.INT8,
np.dtype("complex64"): DataType.COMPLEX64,
np.dtype("complex128"): DataType.COMPLEX128,
}


def _use_new_ir_api():
    """
......
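Editor's note (not part of the diff): with the mapping table added in python/paddle/ir/core.py, the dtype branch in `convert_np_dtype_to_dtype_` reduces to a dictionary lookup. A small hedged sketch of the resulting behaviour, assuming a Paddle build with the new IR bindings enabled:

```python
import numpy as np

from paddle import ir

# The table added above maps NumPy dtypes to paddle DataType values.
assert np.dtype("float32") in ir.core.np_type_to_paddle_type
print(ir.core.np_type_to_paddle_type[np.dtype("float32")])  # DataType.FLOAT32
# uint16 stands in for bfloat16, as the removed comment in framework.py noted.
print(ir.core.np_type_to_paddle_type[np.dtype("uint16")])   # DataType.UINT16

# convert_np_dtype_to_dtype_ now consults this table when the new IR API is
# enabled, instead of the removed if/elif chain.
```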