Unverified · Commit e66beb0b authored by: H hong, committed by: GitHub

Support fetch in new ir (#54826)

* add fetch kernel

* support fetch var in new ir

* fix bug

* polish code

* change array equal to np.testing
Parent 03d6d98c
@@ -952,10 +952,11 @@ void BuildOpFuncList(
     auto op_name = attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().data();
-    if (op_name == "pd.fetch" || op_name == "builtin.combine") {
+    if (op_name == "builtin.combine") {
+      VLOG(6) << "skip process builtin.combine op";
       continue;
     }
     op_func_node.phi_op_name_ = op_name;
     ::ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
@@ -68,9 +68,11 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
   if (FLAGS_enable_new_ir_in_executor) {
     VLOG(6) << "begin to translate" << std::endl;
-    auto base_progrm = paddle::TranslateLegacyProgramToProgram(*program);
+    auto base_program = paddle::TranslateLegacyProgramToProgram(*program);
     auto kernel_program =
-        paddle::dialect::PdOpLowerToKernelPass(base_progrm.get());
+        paddle::dialect::PdOpLowerToKernelPass(base_program.get());
     interpretercores_.emplace_back(std::make_unique<InterpreterCore>(
         place_, std::move(kernel_program), scope_, execution_config));
   } else {
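For orientation: with `FLAGS_enable_new_ir_in_executor` set, the executor now stages legacy-program translation and kernel lowering before constructing the `InterpreterCore`. A standalone toy model of that two-stage flow (every name and type below is an illustrative stand-in, not a Paddle API):

```cpp
#include <memory>
#include <string>
#include <vector>

// Stand-ins for the legacy ProgramDesc and the new-IR program.
struct ProgramDesc { std::vector<std::string> ops; };
struct IrProgram { std::vector<std::string> ops; };

// Stage 1 (toy): translate the legacy program into the new IR.
std::unique_ptr<IrProgram> TranslateLegacyProgramToProgram(const ProgramDesc& p) {
  return std::make_unique<IrProgram>(IrProgram{p.ops});
}

// Stage 2 (toy): lower each op to a concrete phi kernel form.
std::unique_ptr<IrProgram> PdOpLowerToKernelPass(IrProgram* base) {
  auto lowered = std::make_unique<IrProgram>();
  for (const auto& op : base->ops) lowered->ops.push_back("phi_kernel:" + op);
  return lowered;
}

int main() {
  ProgramDesc legacy{{"pd.add", "pd.fetch"}};
  auto base_program = TranslateLegacyProgramToProgram(legacy);
  auto kernel_program = PdOpLowerToKernelPass(base_program.get());
  // In the real executor, kernel_program feeds the InterpreterCore.
}
```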
@@ -17,8 +17,20 @@
   data_transform: {}
   attrs:
   - {typename: str, name: name}
-  outputs: []
+  outputs:
+  - {typename: Tensor, name: out, optional: false, intermediate: false}
   no_need_buffer: null
   data_transform: null
+  infer_meta:
+    func: UnchangedInferMeta
+    param: [x]
+  kernel:
+    func: [fetch]
+    param: [x]
+    backend: null
+    layout: null
+    data_type: null
+    dispatch: {fetch: null}
+    force_backend: null
   inplace: null
   backward: null
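The spec above gives pd.fetch a real output whose metadata is inferred by `UnchangedInferMeta` over `x`. A minimal standalone sketch of what that contract means (the meta type here is a stand-in, not phi's):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for tensor metadata: shape plus a dtype tag.
struct TensorMeta {
  std::vector<int64_t> dims;
  int dtype;
};

// UnchangedInferMeta (param: [x]): the output inherits x's metadata verbatim.
TensorMeta UnchangedInferMeta(const TensorMeta& x) { return x; }

int main() {
  TensorMeta x{{2, 2}, /*dtype=*/0};
  TensorMeta out = UnchangedInferMeta(x);  // meta for the declared `out`
  assert(out.dims == x.dims && out.dtype == x.dtype);
}
```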
@@ -55,15 +55,14 @@ class PhiKernelAdaptor {
   void run(ir::Program* program) {
     auto block = program->block();
     std::unordered_map<ir::Value, std::string> name_map;
-    std::cerr << "run here" << std::endl;
     ir::BuildScope(block, scope_, &name_map);
-    std::cerr << "after buid scope" << std::endl;
     auto* dev_ctx = phi::DeviceContextPool::Instance().Get(phi::CPUPlace());
     phi::Place cpu_place(phi::AllocationType::CPU);
     for (auto it = block->begin(); it != block->end(); ++it) {
-      std::cerr << (*it)->name() << std::endl;
+      VLOG(6) << "begin to run op " << (*it)->name();
       auto attr_map = (*it)->attributes();
       paddle::dialect::OpYamlInfoInterface op_info_interface =
@@ -91,10 +90,6 @@ class PhiKernelAdaptor {
       }
       auto found_it = phi_kernels.find(kernel_key);
       if (found_it == phi_kernels.end()) {
-        std::cerr << "kernel name " << runtime_info.kernel_func[0] << std::endl;
-        std::cerr << "kernel key " << kernel_key.backend() << "\t"
-                  << kernel_key.dtype() << "\t" << kernel_key.layout()
-                  << std::endl;
         PADDLE_THROW(paddle::platform::errors::NotFound(
             "can not find kernel for [%s]", (*it)->name()));
       } else {
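The lookup above is the usual select-or-fail step: build a kernel key from backend/layout/dtype and throw NotFound when no registered kernel matches. A standalone sketch of that pattern (registry and key types are stand-ins, not phi's):

```cpp
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <tuple>

// Stand-in for phi::KernelKey: (backend, layout, dtype).
using KernelKey = std::tuple<std::string, std::string, std::string>;

int main() {
  // Stand-in registry, keyed the same way the real phi kernel map is.
  std::map<KernelKey, std::function<void()>> phi_kernels;
  phi_kernels[{"CPU", "ALL_LAYOUT", "float32"}] = [] { /* run fetch */ };

  KernelKey kernel_key{"CPU", "ALL_LAYOUT", "float32"};
  auto found_it = phi_kernels.find(kernel_key);
  if (found_it == phi_kernels.end()) {
    // Mirrors the PADDLE_THROW(NotFound(...)) branch above.
    throw std::runtime_error("can not find kernel for [pd.fetch]");
  }
  found_it->second();  // launch the matched kernel
}
```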
@@ -35,6 +35,7 @@
 #include "paddle/fluid/framework/tensor_ref_array.h"
 #include "paddle/fluid/ir/dialect/kernel_attribute.h"
 #include "paddle/fluid/ir/dialect/pd_attribute.h"
+#include "paddle/phi/core/enforce.h"

 #include "glog/logging.h"
@@ -57,18 +58,10 @@ void BuildScope(ir::Block* block,
     if (op_name == "pd.fetch") {
       // fetch is a very special op, with no output
       for (size_t i = 0; i < input_num; ++i) {
-        auto ptr = (*it)->operand(i).source();
-        auto var_name = attr_map.at("name").dyn_cast<ir::StrAttribute>().data();
-        PADDLE_ENFORCE_EQ(
-            name_map->count(ptr),
-            true,
-            phi::errors::PreconditionNotMet(
-                "input of fetch op should in name mape, var_name is [%s]",
-                var_name));
-        scope->Rename(name_map->at(ptr), var_name);
-        (*name_map)[ptr] = var_name;
+        auto var = scope->Var("fetch");
+        auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
+        // for now only support one fetch
+        fetch_list->resize(1);
       }
       continue;
     }
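BuildScope now pre-creates a scope variable named "fetch" holding a FetchList with a single slot, which the kernel later writes into. Conceptually a FetchList is a vector of tensor variants whose first alternative is a dense tensor; a standalone model with stand-in types (not paddle::framework::FetchList):

```cpp
#include <cassert>
#include <variant>
#include <vector>

struct DenseTensor { std::vector<float> data; };

// Stand-in for FetchList: a vector of tensor variants whose first
// alternative is DenseTensor, so resize() yields directly usable slots.
using FetchType = std::variant<DenseTensor>;
using FetchList = std::vector<FetchType>;

int main() {
  FetchList fetch_list;
  fetch_list.resize(1);  // "for now only support one fetch", as above
  // Later, the kernel context takes &PADDLE_GET(DenseTensor, at(0)) and
  // FetchKernel writes the fetched value through that pointer.
  std::get<DenseTensor>(fetch_list.at(0)).data = {2.f, 2.f, 2.f, 2.f};
  assert(std::get<DenseTensor>(fetch_list.at(0)).data.size() == 4u);
}
```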
@@ -245,10 +238,21 @@ void BuildInferMetaContext(
     }
   }

+  // [todo] support the full fetch list; for now only a single fetch is handled
+  if (op->attributes().count("op_name") &&
+      (op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data() ==
+       "pd.fetch")) {
+    // process fetch op
+    auto fetch_var = scope->Var("fetch");
+    auto* fetch_list = fetch_var->GetMutable<paddle::framework::FetchList>();
+    auto* out_tensor = &(PADDLE_GET(phi::DenseTensor, fetch_list->at(0)));
+    ctx->EmplaceBackOutput(out_tensor);
+  } else {
+    ir::Value out_ptr = op->result(0);
+    auto name = name_map.at(out_ptr);
+    ctx->EmplaceBackOutput(scope->Var(name)->Get<phi::DenseTensor>());
+  }
 }

 void BuildPhiKernelContext(
@@ -367,18 +371,27 @@ void BuildPhiKernelContext(
     }
   }

+  if (op->attributes().count("op_name") &&
+      (op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data() ==
+       "pd.fetch")) {
+    // process fetch op
+    auto fetch_var = scope->Var("fetch");
+    auto* fetch_list = fetch_var->GetMutable<paddle::framework::FetchList>();
+    auto* out_tensor = &(PADDLE_GET(phi::DenseTensor, fetch_list->at(0)));
+    ctx->EmplaceBackOutput(out_tensor);
+  } else {
     ir::Value out_ptr = op->result(0);
     auto name = name_map.at(out_ptr);
     ctx->EmplaceBackOutput(const_cast<phi::DenseTensor*>(
         &(scope->Var(name)->Get<phi::DenseTensor>())));

     if (output_map != nullptr) {
-      // only deal with single input for now, [todo] need support multi input like
-      // concat
+      // only deal with single input for now, [todo] need support multi input
+      // like concat
       size_t tmp_id = std::atol(name.substr(4, 100).c_str());
       (*output_map)["out"].push_back(tmp_id);
     }
+  }
 }

 }  // namespace ir
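Both context builders now share the same special case: a pd.fetch output aliases slot 0 of the fetch list, while any other op's output is a named scope variable, so the kernel writes directly into the structure that exe.run hands back. A standalone sketch of that routing decision (Scope, tensors, and the helper below are stand-ins):

```cpp
#include <cassert>
#include <map>
#include <string>
#include <variant>
#include <vector>

struct DenseTensor { std::vector<float> data; };
using FetchList = std::vector<std::variant<DenseTensor>>;

// Stand-in scope: named variables plus the single pre-sized fetch list.
struct Scope {
  std::map<std::string, DenseTensor> vars;
  FetchList fetch_list = FetchList(1);  // one slot, as BuildScope pre-sizes
};

// Models the branch above: pd.fetch outputs alias the fetch list,
// every other op's output is a named scope variable.
DenseTensor* OutputFor(Scope* scope, const std::string& op_name,
                       const std::string& var_name) {
  if (op_name == "pd.fetch") {
    return &std::get<DenseTensor>(scope->fetch_list.at(0));
  }
  return &scope->vars[var_name];
}

int main() {
  Scope scope;
  OutputFor(&scope, "pd.add", "var_0")->data = {2.f, 2.f};  // regular op
  OutputFor(&scope, "pd.fetch", "")->data = {2.f, 2.f};     // fetch op
  assert(std::get<DenseTensor>(scope.fetch_list[0]).data.size() == 2u);
}
```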
@@ -572,6 +572,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
       {"name", ir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])},
   };

+  op_output_types.push_back(op_inputs[0].type());
   ir::Operation* operation =
       ir::Operation::Create(op_inputs, attribute_map, op_output_types, op_info);
   program->block()->push_back(operation);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/fetch_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"  // declares phi::Copy

namespace phi {

template <typename T, typename Context>
void FetchKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 DenseTensor* out) {
  // Blocking copy of x into out, with the CPU as the destination place.
  phi::Copy(dev_ctx, x, phi::CPUPlace(), true, out);
}

}  // namespace phi

PD_REGISTER_KERNEL(fetch,
                   CPU,
                   ALL_LAYOUT,
                   phi::FetchKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>,
                   bool) {}
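FetchKernel itself is nothing more than a blocking copy of `x` into `out` with a CPU destination, which is how device-side results become host-visible. A standalone sketch of those semantics (plain C++ model, not phi::Copy):

```cpp
#include <cassert>
#include <vector>

struct DenseTensor { std::vector<float> data; };

// Models phi::Copy(dev_ctx, x, CPUPlace(), /*blocking=*/true, out):
// a synchronous deep copy to host; on a GPU backend this is a D2H transfer.
void FetchKernelModel(const DenseTensor& x, DenseTensor* out) {
  out->data = x.data;
}

int main() {
  DenseTensor x{{2.f, 2.f, 2.f, 2.f}};
  DenseTensor out;
  FetchKernelModel(x, &out);
  assert(out.data == x.data);  // fetched value now lives on the host side
}
```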
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

/**
 * @brief This kernel is used to fetch a tensor from the scope.
 * @param ctx device context
 * @param x   the input tensor of fetch
 * @param out the output tensor of fetch
 */
template <typename T, typename Context>
void FetchKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out);

}  // namespace phi
@@ -22,30 +22,22 @@ import paddle

 paddle.enable_static()


-# class TestNewIr(unittest.TestCase):
-#     def test_with_new_ir(self):
-#         place = paddle.CPUPlace()
-#         exe = paddle.static.Executor(place)
+class TestNewIr(unittest.TestCase):
+    def test_with_new_ir(self):
+        place = paddle.CPUPlace()
+        exe = paddle.static.Executor(place)

-#         x = paddle.ones([2, 2], dtype="float32")
-#         y = paddle.ones([2, 2], dtype="float32")
+        x = paddle.ones([2, 2], dtype="float32")
+        y = paddle.ones([2, 2], dtype="float32")

-#         z = x + y
-#         out = exe.run(
-#             paddle.static.default_main_program(), {}, fetch_list=[z.name]
-#         )
+        z = x + y
+        out = exe.run(
+            paddle.static.default_main_program(), {}, fetch_list=[z.name]
+        )

-#         gold_res = np.ones([2, 2], dtype="float32") * 2
+        gold_res = np.ones([2, 2], dtype="float32") * 2

-#         self.assertEqual(
-#             np.array_equal(
-#                 np.array(
-#                     paddle.static.global_scope().find_var(z.name).get_tensor()
-#                 ),
-#                 gold_res,
-#             ),
-#             True,
-#         )
+        np.testing.assert_array_equal(out[0], gold_res)


 class TestCombineOp(unittest.TestCase):
@@ -63,15 +55,7 @@ class TestCombineOp(unittest.TestCase):
         gold_res = np.ones([2, 2], dtype="float32") * 2

-        self.assertEqual(
-            np.array_equal(
-                np.array(
-                    paddle.static.global_scope().find_var(z.name).get_tensor()
-                ),
-                gold_res,
-            ),
-            True,
-        )
+        np.testing.assert_array_equal(out[0], gold_res)

if __name__ == "__main__":