Commit 598c2a5f authored by huzhiqiang, committed by GitHub

[opencl] add pre_process attribute into layout op (#3001)

Parent 4fbbdfc6
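Overview of the change: if the model directory passed to CxxConfig contains the marker substring "OPENCL_PRE_PRECESS", CxxPaddleApiImpl::Init() appends a new MIR pass, type_layout_cast_preprocess_pass, to the default optimization pipeline. That pass tags every layout op with a process_type attribute, which LayoutOp::AttachImpl() copies into LayoutParam so the OpenCL layout kernel can select its pre-processing path.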
@@ -35,7 +35,13 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
   Env<TARGET(kCUDA)>::Init();
 #endif
   auto places = config.valid_places();
-  raw_predictor_.Build(config, places);
+  std::vector<std::string> passes{};
+  auto use_layout_preprocess_pass =
+      config.model_dir().find("OPENCL_PRE_PRECESS");
+  if (use_layout_preprocess_pass != std::string::npos) {
+    passes = {"type_layout_cast_preprocess_pass"};
+  }
+  raw_predictor_.Build(config, places, passes);
   mode_ = config.power_mode();
   threads_ = config.threads();
...
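For context, a minimal sketch of how a caller would trigger this path. The CxxConfig / set_model_dir / set_valid_places / CreatePaddlePredictor calls are the public Paddle-Lite lite_api surface; the model path and the exact Place list are illustrative assumptions, and the only requirement this commit introduces is that the directory name contain the marker substring checked above:

    #include "paddle_api.h"  // Paddle-Lite public API header

    using namespace paddle::lite_api;  // NOLINT

    int main() {
      CxxConfig config;
      // "OPENCL_PRE_PRECESS" in the path is the marker that
      // CxxPaddleApiImpl::Init() searches for; any directory name
      // containing it enables the preprocess pass.
      config.set_model_dir("./mobilenet_v1_OPENCL_PRE_PRECESS");
      config.set_valid_places({
          // Layout macro assumed; adjust to your Paddle-Lite version.
          Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageDefault)},
          Place{TARGET(kARM), PRECISION(kFloat)},  // CPU fallback
      });
      // Build() receives {"type_layout_cast_preprocess_pass"} internally.
      auto predictor = CreatePaddlePredictor<CxxConfig>(config);
      return 0;
    }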
@@ -204,6 +204,28 @@ void TypeLayoutTransformPass::SetValidPlaces(
   valid_places_ = valid_places;
 }
 
+void OpenCLTypeLayoutTransformPass::Apply(
+    const std::unique_ptr<SSAGraph>& graph) {
+  // Tag every "layout" op in the graph with process_type = 1 so the
+  // OpenCL layout kernel can run its pre-/post-processing path.
+  VLOG(4) << "\n" << Visualize(graph.get());
+  std::list<Node*> nodes;
+  for (auto& node : graph->StmtTopologicalOrder()) {
+    nodes.push_back(node);
+  }
+
+  VLOG(4) << "nodes.size():" << nodes.size();
+  for (auto& node : nodes) {
+    VLOG(4) << "!node->IsStmt():" << !node->IsStmt();
+    if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
+    if (node->AsStmt().op_type() == "layout") {
+      auto new_op = node->AsStmt().mutable_op_info();
+      int process_type = 1;
+      new_op->SetAttr("process_type", process_type);
+    }
+  }
+  VLOG(4) << "\n" << Visualize(graph.get());
+}
+
 }  // namespace mir
 }  // namespace lite
 }  // namespace paddle
@@ -213,3 +235,9 @@ REGISTER_MIR_PASS(type_layout_cast_pass,
     .BindTargets({TARGET(kAny)})
     .BindKernel("layout_once")
     .BindKernel("layout");
+
+REGISTER_MIR_PASS(type_layout_cast_preprocess_pass,
+                  paddle::lite::mir::OpenCLTypeLayoutTransformPass)
+    .BindTargets({TARGET(kAny)})
+    .BindKernel("layout_once")
+    .BindKernel("layout");
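The name registered here, type_layout_cast_preprocess_pass, is exactly the string pushed into the pass list in CxxPaddleApiImpl::Init() above; REGISTER_MIR_PASS is what makes the pass resolvable by that name when the Optimizer runs the pipeline.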
@@ -57,6 +57,15 @@ class TypeLayoutTransformPass : public ProgramPass {
   std::vector<Place> valid_places_;
 };
 
+// add preprocess and postprocess attribute for layout op
+class OpenCLTypeLayoutTransformPass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+
+ private:
+  std::vector<Place> valid_places_;
+};
+
 }  // namespace mir
 }  // namespace lite
 }  // namespace paddle
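Note that the new class derives from ProgramPass directly rather than extending TypeLayoutTransformPass; it carries its own valid_places_ member for symmetry with the class above, though the Apply() implementation in this commit never reads it.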
@@ -53,7 +53,7 @@ class Optimizer {
     SpecifyKernelPickTactic(kernel_pick_factor);
     InitTargetTypeTransformPass();
 
-    if (passes.empty()) {
+    if (passes.empty() || passes.size() == 1) {
       std::vector<std::string> passes_local{
           {"lite_quant_dequant_fuse_pass",         //
            "weight_quantization_preprocess_pass",  //
@@ -112,6 +112,9 @@ class Optimizer {
            "runtime_context_assign_pass",
            "argument_type_display_pass",
            "memory_optimize_pass"}};
+      if (passes.size() == 1) {
+        passes_local.push_back(passes[0]);
+      }
       RunPasses(passes_local);
     } else {
       RunPasses(passes);
...
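The net behavior in Optimizer: an empty passes list runs the stock pipeline as before; a single-element list now runs the stock pipeline with that one pass appended after memory_optimize_pass; a longer list still replaces the pipeline wholesale.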
@@ -35,6 +35,9 @@ bool LayoutOp::AttachImpl(const cpp::OpDesc &opdesc,
   auto out = opdesc.Output("Out").front();
   param_.x = GetTensor(scope, x);
   param_.y = GetMutableTensor(scope, out);
+  if (opdesc.HasAttr("process_type")) {
+    param_.process_type = opdesc.GetAttr<int>("process_type");
+  }
   return true;
 }
 
 std::string LayoutOp::DebugString() const { return "layout_op"; }
...
@@ -62,6 +62,7 @@ struct IoCopyParam {
 struct LayoutParam {
   const lite::Tensor* x{};
   lite::Tensor* y{};
+  int process_type{0};
 };
 
 struct CalibParam {
...
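To close the loop, a hedged sketch of how a kernel might consume the new field once AttachImpl has populated it. The function below is hypothetical and not part of this commit; only the LayoutParam field is real:

    #include "lite/operators/op_params.h"

    // Hypothetical consumer of LayoutParam::process_type (not in this commit).
    void RunLayoutKernel(const paddle::lite::operators::LayoutParam& param) {
      if (param.process_type == 1) {
        // Op was tagged by type_layout_cast_preprocess_pass: take the
        // OpenCL image pre-/post-processing path.
      } else {
        // Default layout conversion between param.x and param.y.
      }
    }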