未验证 提交 a88d36aa 编写于 作者: H hong 提交者: GitHub

New ir support save combine (#55538)

* new ir support save combine

* update

* polish code
上级 74266762
......@@ -974,18 +974,20 @@ void BuildOpFuncList(
VLOG(6) << "op name" << op_func_node.phi_op_name_;
dialect::OpYamlInfoParser op_yaml_info_parser(impl->get_op_info_());
::ir::BuildPhiContext<
phi::InferMetaContext,
phi::MetaTensor,
phi::MetaTensor,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
false>((*it),
value_2_name_map,
scope,
local_scope,
op_yaml_info_parser,
&(op_func_node.infer_meta_context_));
if (op_func_node.infer_meta_interface_) {
::ir::BuildPhiContext<
phi::InferMetaContext,
phi::MetaTensor,
phi::MetaTensor,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
false>((*it),
value_2_name_map,
scope,
local_scope,
op_yaml_info_parser,
&(op_func_node.infer_meta_context_));
}
auto kernel_name =
attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
......
......@@ -161,7 +161,7 @@ Instruction::Instruction(size_t id,
is_artificial_ = true;
}
if (op_func_node_.infer_meta_interface_ != nullptr) {
if (op_func_node_.phi_kernel_ != nullptr) {
pre_define_context_ = true;
}
PADDLE_ENFORCE_GE(id,
......
......@@ -1028,8 +1028,10 @@ void NewIRInterpreter::RunInstruction(const Instruction& instr_node) {
VLOG(5) << "run new ir selected kernel";
auto op_func_node = const_cast<OpFuncNode*>((instr_node.OpFunc()));
VLOG(5) << "begin to run op " << op_func_node->phi_op_name_;
op_func_node->infer_meta_interface_->infer_meta_(
&(op_func_node->infer_meta_context_));
if (op_func_node->infer_meta_interface_) {
op_func_node->infer_meta_interface_->infer_meta_(
&(op_func_node->infer_meta_context_));
}
VLOG(5) << "after run infer meta";
(*(op_func_node->phi_kernel_))(&(op_func_node->kernel_context_));
VLOG(5) << "after run kernel";
......
......@@ -38,6 +38,33 @@
inplace: null
backward: null
# Parsed-op description for `save_combine`: serializes the list of input
# tensors X to `file_path` (or to a memory buffer when save_to_memory is
# true) via the save_combine_tensor kernel.
- name: save_combine
  inputs:
  - typename: Tensor[]
    name: X
    optional: false
    no_need_buffer: false
    data_transform: {}
  attrs:
  - {typename: str, name: file_path}
  - {typename: bool, name: overwrite}
  - {typename: bool, name: save_as_fp16}
  - {typename: bool, name: save_to_memory}
  outputs:
  # `out` is optional — presumably only produced for the in-memory path;
  # TODO(review): confirm against the kernel signature.
  - {typename: Tensor, name: out, optional: true, intermediate: false}
  no_need_buffer: null
  data_transform: null
  kernel:
    func: [save_combine_tensor]
    param: [X, file_path, overwrite, save_as_fp16, save_to_memory]
    backend: null
    layout: null
    data_type: null
    # NOTE(review): the dispatch key `fetch` looks copy-pasted from the
    # `fetch` entry — sibling entries key dispatch by their own op name
    # (i.e. `save_combine`). Verify against the yaml generator before use.
    dispatch: {fetch: null}
    force_backend: null
  inplace: null
  backward: null
- name: share_buffer_
inputs:
- typename: Tensor[]
......
......@@ -106,7 +106,6 @@ class TestSelectedRows(unittest.TestCase):
def test_with_new_ir(self):
paddle.enable_static()
# TODO(phlrain): support selected rows in GPU
# place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
......@@ -241,6 +240,19 @@ class TestSplitOp(unittest.TestCase):
np.testing.assert_array_equal(out[0], np_a[0:2])
class TestJitSaveOp(unittest.TestCase):
    """Exercises ``paddle.jit.save`` (which lowers to the new-IR
    ``save_combine`` op) on a small Linear layer."""

    def test_with_new_ir(self):
        # Local imports keep this scrape/test file's top-level deps unchanged.
        import os
        import tempfile

        # jit.save traces the layer, so dygraph mode is required; restore
        # static mode afterwards so global state does not leak into other
        # tests in this module.
        paddle.disable_static()
        self.addCleanup(paddle.enable_static)

        linear = paddle.nn.Linear(10, 10)

        # Save into a temporary directory instead of a hard-coded
        # "example_model/linear" path, so the test leaves no artifacts
        # in the working directory and reruns cleanly.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "linear")
            paddle.jit.save(
                linear,
                path,
                input_spec=[
                    paddle.static.InputSpec([10, 10], 'float32', 'x')
                ],
            )
            # The save must actually have produced serialized files.
            self.assertTrue(
                os.listdir(tmp_dir), "paddle.jit.save produced no files"
            )
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册