未验证 提交 360b4013 编写于 作者: H huzhiqiang 提交者: GitHub

[opencl] Add pre_process attribute to the layout op (#3001)

上级 857b7116
......@@ -35,7 +35,13 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
Env<TARGET(kCUDA)>::Init();
#endif
auto places = config.valid_places();
raw_predictor_.Build(config, places);
std::vector<std::string> passes{};
auto use_layout_preprocess_pass =
config.model_dir().find("OPENCL_PRE_PRECESS");
if (use_layout_preprocess_pass != std::string::npos) {
passes = {"type_layout_cast_preprocess_pass"};
}
raw_predictor_.Build(config, places, passes);
mode_ = config.power_mode();
threads_ = config.threads();
......
......@@ -204,6 +204,28 @@ void TypeLayoutTransformPass::SetValidPlaces(
valid_places_ = valid_places;
}
void OpenCLTypeLayoutTransformPass::Apply(
    const std::unique_ptr<SSAGraph>& graph) {
  // Dump the graph before mutation for debugging.
  VLOG(4) << "\n" << Visualize(graph.get());

  // Snapshot the statement nodes in topological order before mutating them.
  std::list<Node*> nodes;
  for (auto& stmt_node : graph->StmtTopologicalOrder()) {
    nodes.push_back(stmt_node);
  }
  VLOG(4) << "nodes.size():" << nodes.size();

  for (auto& node : nodes) {
    VLOG(4) << "!node->IsStmt():" << !node->IsStmt();
    if (!node->IsStmt()) continue;
    auto& stmt = node->AsStmt();
    if (stmt.op_type() == "while") continue;  // skip control-flow ops
    if (stmt.op_type() != "layout") continue;
    // Tag the layout op so its kernel takes the pre-process code path
    // (read back in LayoutOp::AttachImpl as the "process_type" attribute).
    const int process_type = 1;
    stmt.mutable_op_info()->SetAttr("process_type", process_type);
  }

  // Dump the graph again after tagging.
  VLOG(4) << "\n" << Visualize(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -213,3 +235,9 @@ REGISTER_MIR_PASS(type_layout_cast_pass,
.BindTargets({TARGET(kAny)})
.BindKernel("layout_once")
.BindKernel("layout");
// Registers the OpenCL variant of the layout-transform pass under the name
// "type_layout_cast_preprocess_pass"; it is appended to the pass list by
// CxxPaddleApiImpl::Init when explicitly requested, and binds the same
// layout kernels as the base type_layout_cast_pass.
REGISTER_MIR_PASS(type_layout_cast_preprocess_pass,
                  paddle::lite::mir::OpenCLTypeLayoutTransformPass)
    .BindTargets({TARGET(kAny)})
    .BindKernel("layout_once")
    .BindKernel("layout");
......@@ -57,6 +57,15 @@ class TypeLayoutTransformPass : public ProgramPass {
std::vector<Place> valid_places_;
};
// Pass that tags every "layout" op in the graph with a "process_type"
// attribute so the OpenCL layout kernel can select its pre-process path.
class OpenCLTypeLayoutTransformPass : public ProgramPass {
 public:
  // Walks the statement nodes and sets process_type = 1 on layout ops.
  void Apply(const std::unique_ptr<SSAGraph>& graph) override;

 private:
  // NOTE(review): not referenced by Apply() in this pass — appears copied
  // from TypeLayoutTransformPass; confirm before removing.
  std::vector<Place> valid_places_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -53,7 +53,7 @@ class Optimizer {
SpecifyKernelPickTactic(kernel_pick_factor);
InitTargetTypeTransformPass();
if (passes.empty()) {
if (passes.empty() || passes.size() == 1) {
std::vector<std::string> passes_local{
{"lite_quant_dequant_fuse_pass", //
"weight_quantization_preprocess_pass", //
......@@ -112,6 +112,9 @@ class Optimizer {
"runtime_context_assign_pass",
"argument_type_display_pass",
"memory_optimize_pass"}};
if (passes.size() == 1) {
passes_local.push_back(passes[0]);
}
RunPasses(passes_local);
} else {
RunPasses(passes);
......
......@@ -35,6 +35,9 @@ bool LayoutOp::AttachImpl(const cpp::OpDesc &opdesc,
auto out = opdesc.Output("Out").front();
param_.x = GetTensor(scope, x);
param_.y = GetMutableTensor(scope, out);
if (opdesc.HasAttr("process_type")) {
param_.process_type = opdesc.GetAttr<int>("process_type");
}
return true;
}
// Human-readable identifier for this op, used in logs and graph dumps.
std::string LayoutOp::DebugString() const {
  return "layout_op";
}
......
......@@ -62,6 +62,7 @@ struct IoCopyParam {
// Parameters consumed by the layout-conversion op/kernel.
struct LayoutParam {
  const lite::Tensor* x{};  // input tensor (read-only, non-owning)
  lite::Tensor* y{};        // output tensor (non-owning)
  // 0 = plain layout transform (default); set to 1 via the op's
  // "process_type" attribute (see LayoutOp::AttachImpl) when the OpenCL
  // pre-process pass has tagged the op.
  int process_type{0};
};
struct CalibParam {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册