Commit aefca7f1 authored by: S superjomn

add files

Parent commit: eada00c2
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/pass_registry.h"
// Intentionally empty: this header exists only for the USE_MIR_PASS
// registrations below; no symbols are declared in the namespaces.
namespace paddle {
namespace lite {
namespace mir {} // namespace mir
} // namespace lite
} // namespace paddle
// Force-link the statically registered MIR passes so their registration
// objects are not dropped by the linker. USE_MIR_PASS comes from
// pass_registry.h (included above).
// NOTE(review): pass names must match the strings used at registration
// time — confirm against each pass's REGISTER_MIR_PASS site.
USE_MIR_PASS(demo);
USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(io_complement_pass);
USE_MIR_PASS(generate_program_pass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/optimizer.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include "paddle/fluid/lite/core/program_fake_utils.h"
namespace paddle {
namespace lite {

// Smoke test: build a fake FC-chain program, configure the static kernel
// picker to weigh target then precision, and run the optimizer end to end
// on a host/float place.
TEST(Optimizer, test) {
  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});

  // The pass is looked up from the global registry by its registered name.
  auto* kernel_pick_pass =
      mir::PassManager::Global().LookUp<mir::StaticKernelPickPass>(
          "static_kernel_pick_pass");
  ASSERT_TRUE(kernel_pick_pass != nullptr);
  kernel_pick_pass->mutable_kernel_pick_factors()
      ->ConsiderTarget()
      .ConsiderPrecision();

  Optimizer optimizer;
  optimizer.Run(FakeProgram(), valid_places);
}

}  // namespace lite
}  // namespace paddle

USE_LITE_OP(fc);
USE_LITE_KERNEL(fc, kHost, kFloat);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/program_fake_utils.h"
#include "paddle/fluid/lite/core/op_registry.h"
// Intentionally empty translation unit body: program_fake_utils.h is
// header-only, so this .cc exists only to compile the included headers.
namespace paddle {
namespace lite {
namespace mir {} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {

// Builds a small fake program for tests: a chain of three FC ops
//   x -fc-> out0 -fc-> out1 -fc-> out2
// with 100x100 weights/activations, all on the host/float place.
// Returns the populated mir::Program by value.
//
// NOTE: this function is defined in a header (#pragma once above), so it
// must be `inline` — otherwise every TU that includes it emits its own
// definition and linking fails with multiple-definition errors (ODR).
inline mir::Program FakeProgram() {
  mir::Program program;
  // NOTE(review): raw `new` with no matching delete here — presumably the
  // Program (or the owning test) is responsible for the Scope's lifetime;
  // confirm ownership before reusing this helper outside tests.
  program.scope = new lite::Scope;

  // Appends one FC op consuming variable `x`; creates its weight/bias/out
  // variables and returns the output variable's name for chaining.
  auto add_fc = [&](int id, std::string x) {
    // create variables
    std::string w1 = "w" + std::to_string(id);
    std::string b1 = "b" + std::to_string(id);
    std::string out1 = "out" + std::to_string(id);
    auto w1v = program.scope->Var(w1)->GetMutable<Tensor>();
    auto b1v = program.scope->Var(b1)->GetMutable<Tensor>();
    auto out1v = program.scope->Var(out1)->GetMutable<Tensor>();

    // Describe the FC op.
    framework::OpDesc desc;
    desc.SetInput("Input", {x});
    desc.SetInput("W", {w1});
    desc.SetInput("Bias", {b1});
    desc.SetOutput("Out", {out1});
    desc.SetType("fc");
    desc.SetAttr("in_num_col_dims", 1);
    desc.Flush();

    // Record the weights as program inputs (the op output is not added —
    // it becomes the next op's input via the returned name).
    program.tmp_vars.push_back(w1);
    program.tmp_vars.push_back(b1);

    auto fc_op = LiteOpRegistry::Global().Create("fc");
    // Pick the kernel before Attach so the op is fully set up in order.
    fc_op->PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
    fc_op->Attach(desc, program.scope);
    program.ops.emplace_back(std::move(fc_op));

    w1v->Resize({100, 100});
    b1v->Resize({100, 1});
    out1v->Resize({100, 100});

    return out1;
  };

  // x1, w1, b1 -fc-> out1
  // out1, w2, b2 -fc-> out2
  std::string x = "x";
  program.tmp_vars.push_back(x);
  auto* xv = program.scope->Var(x)->GetMutable<Tensor>();
  xv->Resize({100, 100});

  for (int i = 0; i < 3; i++) {
    x = add_fc(i, x);
  }
  return program;
}

}  // namespace lite
}  // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/types.h"
namespace paddle {
namespace lite {
namespace core {

// KernelPickFactor is a bit set: each Consider* call ORs the matching
// Factor bit into data_, and each Is*Considered query tests that bit.

KernelPickFactor& KernelPickFactor::ConsiderDataLayout() {
  const auto bit = static_cast<int>(Factor::DataLayoutFirst);
  data_ |= bit;
  return *this;  // allow chained Consider* calls
}

KernelPickFactor& KernelPickFactor::ConsiderPrecision() {
  const auto bit = static_cast<int>(Factor::PrecisionFirst);
  data_ |= bit;
  return *this;
}

KernelPickFactor& KernelPickFactor::ConsiderTarget() {
  const auto bit = static_cast<int>(Factor::TargetFirst);
  data_ |= bit;
  return *this;
}

KernelPickFactor& KernelPickFactor::ConsiderDevice() {
  const auto bit = static_cast<int>(Factor::DeviceFirst);
  data_ |= bit;
  return *this;
}

bool KernelPickFactor::IsPrecisionConsidered() const {
  return (data_ & static_cast<int>(Factor::PrecisionFirst)) != 0;
}

bool KernelPickFactor::IsTargetConsidered() const {
  return (data_ & static_cast<int>(Factor::TargetFirst)) != 0;
}

bool KernelPickFactor::IsDataLayoutConsidered() const {
  return (data_ & static_cast<int>(Factor::DataLayoutFirst)) != 0;
}

}  // namespace core
}  // namespace lite
}  // namespace paddle
Markdown is supported
0% loaded.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to comment