未验证 提交 2479664a 编写于 作者: Z Zhen Wang 提交者: GitHub

Add FLAGS_allow_cinn_ops & FLAGS_deny_cinn_ops for controlling op types used...

Add FLAGS_allow_cinn_ops & FLAGS_deny_cinn_ops for controlling op types used in training with CINN. (#36842)

* Update UT test_parallel_executor_run_cinn.py.

* Add FLAGS_allow_cinn_ops & FLAGS_deny_cinn_ops & FLAGS_cinn_ops_delim.

* Use the custom StringSplit function and remove the FLAGS_cinn_ops_delim flag.

* Add FlagController test.

* Apply lock to the cache_ only in CinnCompiler.

* Add VizGraph & ReadableKey method for CinnCompiler.

* Update the dot style of VizGraph in CinnCompiler.
上级 fb394695
......@@ -2,7 +2,7 @@ cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper l
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors enforce)
cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinn)
cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)
cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS framework_proto graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)
cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
cc_test(build_cinn_pass_test SRCS build_cinn_pass_test.cc DEPS build_cinn_pass cinn_compiler)
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <iterator>
#include <memory>
#include <regex>
#include <string>
#include <unordered_map>
#include <unordered_set>
......@@ -25,6 +26,8 @@ limitations under the License. */
#include "cinn/frontend/op_mapper_registry.h"
#include "cinn/frontend/op_mappers/use_op_mappers.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
......@@ -34,6 +37,9 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
DECLARE_string(allow_cinn_ops);
DECLARE_string(deny_cinn_ops);
namespace paddle {
namespace framework {
namespace paddle2cinn {
......@@ -46,6 +52,20 @@ using GraphNodeSet = std::unordered_set<Node*>;
using GraphNodeMap = std::unordered_map<Node*, Node*>;
namespace {
// The delim(`;`) that is used to split the FLAGS_allow_cinn_ops
// & FLAGS_deny_cinn_ops.
constexpr char kDelim[] = ";";
std::unordered_set<std::string> StringSplit(const std::string& str,
const std::string& delim) {
std::regex reg(delim);
std::unordered_set<std::string> elems{
std::sregex_token_iterator(str.begin(), str.end(), reg, -1),
std::sregex_token_iterator()};
elems.erase("");
return elems;
}
int ExtractOpRole(const GraphNodeSet& cluster) {
std::unordered_set<int> op_roles;
std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName();
......@@ -339,10 +359,27 @@ void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster,
// all of op node supported by CINN. We using OpMapperRegistry
// to check whether the op node supported by CINN.
void SearchAllSubgraphs(Graph* graph) {
auto teller = [](const Node* node) {
return ::cinn::frontend::OpMapperRegistry::Global()->Find(node->Name()) !=
nullptr;
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
auto teller = [&allow_ops, &deny_ops](const Node* node) {
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
node->Name()) != nullptr;
// if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops
if (allow_ops.size()) {
return registered && allow_ops.count(node->Name());
}
// if the op type is registered in CINN and deny_ops is not empty, return
// true only when it is not in deny_ops
if (deny_ops.size()) {
return registered && !deny_ops.count(node->Name());
}
// if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops,
// return true only when it is registered in CINN
return registered;
};
VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops;
VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops;
std::vector<GraphNodeVec> clusters =
framework::ir::SubgraphDetector(graph, teller)();
......@@ -375,7 +412,8 @@ void SearchAllSubgraphs(Graph* graph) {
// save it in CinnCompiler
std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph(
cluster_set, cluster_internals, cluster_inputs, cluster_outputs));
VLOG(4) << "Compilation Key: " << compilation_key;
VLOG(4) << "Compilation Key:\n"
<< cinn_compiler->ReadableKey(compilation_key);
// Replace the found cluster to a new cinn op node
ReplaceSubGraphWithCinnOpNode(cluster_set, cluster_inputs, cluster_outputs,
......
......@@ -14,14 +14,15 @@
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include "cinn/common/target.h"
#include "cinn/common/type.h"
#include "cinn/frontend/decomposer/use_decomposer.h"
#include "cinn/frontend/net_builder.h" // need to remove after
#include "cinn/frontend/pass/use_program_pass.h"
#include "cinn/frontend/program_pass.h"
#include "cinn/frontend/syntax.h"
......@@ -29,19 +30,26 @@
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/pass/use_pass.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace framework {
namespace paddle2cinn {
using ir::Graph;
using ir::Node;
using inference::analysis::Dot;
using ::cinn::common::Target;
using ::cinn::common::Float;
using ::cinn::hlir::framework::GraphCompiler;
......@@ -54,47 +62,121 @@ CinnCompiler* CinnCompiler::GetInstance() {
return &instance;
}
const CinnCompiledObject& CinnCompiler::Compile(
const Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) {
CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
bool exist = false;
{
AutoRDLock r_guard{&rwlock_};
exist = cache_.count(cur_key) != 0;
}
if (!exist) {
real_compiled_num_++;
auto compiled_res = CompileGraph(graph, input_tensors, target);
AutoWRLock w_guard{&rwlock_};
if (!cache_.count(cur_key)) {
cache_[cur_key] = std::move(compiled_res);
}
}
AutoRDLock guard{&rwlock_};
const auto& cached_boj = *cache_[cur_key];
return cached_boj;
}
const CinnCompiledObject& CinnCompiler::Compile(
const std::string& compilation_key,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) {
VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(compilation_key);
const auto& graph = FindGraph(compilation_key);
return Compile(graph, input_tensors, target);
}
std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
std::string graph_key;
ProgramDesc program;
GraphToProgram(*graph, &program);
program.Proto()->SerializeToString(&graph_key);
if (!graphs_.count(graph_key)) {
PADDLE_ENFORCE_EQ(
graphs_.count(graph_key), 0,
platform::errors::PreconditionNotMet(
"The graph to be added is already in CinnCompiler, which is:\n",
VizGraph(graph_key).c_str()));
graphs_[graph_key] = std::move(graph);
} else {
LOG(WARNING)
<< "The graph being added is already in CinnCompiler. Its key is:\n"
<< graph_key;
}
VLOG(4) << "-- Add a graph into CinnCompiler, which is:\n"
<< VizGraph(graph_key);
return graph_key;
}
const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const {
PADDLE_ENFORCE_NE(
graphs_.count(graph_key), 0,
platform::errors::InvalidArgument("Can not find the target graph: %s",
graph_key.c_str()));
platform::errors::PreconditionNotMet(
"Can not find the target graph, of which the key is:\n%s",
ReadableKey(graph_key).c_str()));
return *graphs_.at(graph_key);
}
const CinnCompiledObject& CinnCompiler::Compile(
const Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) {
CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
if (!cache_.count(cur_key)) {
real_compiled_num_++;
cache_[cur_key] = CompileGraph(graph, input_tensors, target);
std::string CinnCompiler::VizGraph(const std::string& key) const {
Dot dot;
std::unordered_map<const Node*, std::string> node2dot;
const Graph& graph = FindGraph(key);
int id = 0;
// Create nodes
for (const Node* n : graph.Nodes()) {
std::string node_id = "Node" + std::to_string(id++);
if (n->IsOp()) {
dot.AddNode(
node_id,
{Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"),
Dot::Attr("color", "#303A3A"), Dot::Attr("fontcolor", "#ffffff")},
n->Name());
} else if (n->IsVar()) {
auto label = n->Name();
if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
auto shape = n->Var()->GetShape();
std::vector<std::string> shape_str(shape.size());
std::transform(shape.begin(), shape.end(), shape_str.begin(),
[](const auto& val) { return std::to_string(val); });
label += "\n" + string::join_strings(shape_str, ',');
}
dot.AddNode(
node_id,
{Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"),
Dot::Attr("color", n->Var()->IsParameter() ? "#148b97" : "#dddddd"),
Dot::Attr("fontcolor",
n->Var()->IsParameter() ? "#ffffff" : "#000000")},
label);
}
node2dot[n] = node_id;
}
// Create edges
for (const Node* n : graph.Nodes()) {
const auto& src_id = node2dot.at(n);
for (auto* out : n->outputs) {
const auto& dest_id = node2dot.at(out);
dot.AddEdge(src_id, dest_id, {});
}
return *cache_[cur_key];
}
return dot.Build();
}
const CinnCompiledObject& CinnCompiler::Compile(
const std::string& compilation_key,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) {
const auto& graph = FindGraph(compilation_key);
return Compile(graph, input_tensors, target);
std::string CinnCompiler::ReadableKey(const std::string& key) const {
proto::ProgramDesc desc;
desc.ParseFromString(key);
return desc.DebugString();
}
void CinnCompiler::Clear() {
{
AutoWRLock guard{&rwlock_};
graphs_.clear();
cache_.clear();
}
real_compiled_num_ = 0;
}
std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
......@@ -107,7 +189,7 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
ProgramPass::Apply(&frontend_program, target, {"Decomposer"});
auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>(
frontend_program, target);
VLOG(4) << "The " << real_compiled_num_ << "-th compilation ("
VLOG(4) << "-- The " << real_compiled_num_ << "-th compilation ("
<< target.arch_str() << "), and its related graph:\n"
<< cinn_graph->Visualize();
ApplyPass(cinn_graph.get(), "OpFusion");
......
......@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/macros.h"
......@@ -64,6 +65,12 @@ class CinnCompiler {
const ir::Graph& FindGraph(const std::string& key) const;
std::string VizGraph(const std::string& key) const;
std::string ReadableKey(const std::string& key) const;
void Clear();
std::int64_t real_compiled_num() const { return real_compiled_num_; }
~CinnCompiler() = default;
......@@ -80,6 +87,7 @@ class CinnCompiler {
CinnCacheKey::Hash>
cache_;
std::atomic_int64_t real_compiled_num_{0};
mutable RWLock rwlock_;
DISABLE_COPY_AND_ASSIGN(CinnCompiler);
};
......
......@@ -14,12 +14,20 @@
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include <algorithm>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "cinn/common/target.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -29,13 +37,76 @@
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
DECLARE_string(allow_cinn_ops);
DECLARE_string(deny_cinn_ops);
namespace paddle {
namespace framework {
namespace paddle2cinn {
using ir::Graph;
using ::cinn::common::Target;
namespace {
template <typename T, typename Alloc = std::allocator<T>>
std::ostream& operator<<(std::ostream& os, const std::vector<T, Alloc>& vec) {
os << "{ ";
for (auto e : vec) {
os << e << " ";
}
os << "}\n";
return os;
}
// Get compilation_key values
std::vector<std::string> GetCompilationKeys(const Graph& graph) {
std::vector<std::string> compilation_keys;
for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) {
compilation_keys.emplace_back(
BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
}
}
return compilation_keys;
}
// Extract op types from a graph
std::unordered_set<std::string> ExtractOpTypes(const Graph& graph) {
std::unordered_set<std::string> op_types;
for (auto& node : graph.Nodes()) {
if (node->IsOp()) {
op_types.emplace(node->Name());
}
}
return op_types;
}
// Get inputs info
std::unordered_map<std::string, std::vector<int64_t>> GetInputsInfo(
const std::string& key, const Graph& graph) {
std::unordered_set<std::string> inputs;
for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) {
if (BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)) !=
key) {
continue;
}
for (auto in_var_name : node->Op()->InputArgumentNames()) {
VLOG(4) << "get an input name: " << in_var_name;
inputs.emplace(in_var_name);
}
}
}
std::unordered_map<std::string, std::vector<int64_t>> inputs_info;
for (auto& node : graph.Nodes()) {
if (node->IsVar() && inputs.count(node->Name())) {
VLOG(4) << node->Name() << " : " << node->Var()->GetShape();
inputs_info.emplace(node->Name(), node->Var()->GetShape());
}
}
return inputs_info;
}
// X -
// | -> mul -> MUL_OUT -
// Y - | -> elementwise_add -> ADD_OUT -> relu -> RELU_OUT
......@@ -65,6 +136,9 @@ std::unique_ptr<Graph> CreateGraph() {
auto* mul_out = global_block->Var("MUL_OUT");
mul_out->SetType(proto::VarType::LOD_TENSOR);
mul_out->SetLoDLevel(0);
mul_out->SetDataType(proto::VarType::FP32);
mul_out->SetShape({1000, 100});
mul_op->SetOutput("Out", {mul_out->Name()});
// add
......@@ -83,6 +157,9 @@ std::unique_ptr<Graph> CreateGraph() {
auto* add_out = global_block->Var("ADD_OUT");
add_out->SetType(proto::VarType::LOD_TENSOR);
add_out->SetLoDLevel(0);
add_out->SetDataType(proto::VarType::FP32);
add_out->SetShape({1000, 100});
add_op->SetOutput("Out", {add_out->Name()});
// relu
......@@ -92,11 +169,59 @@ std::unique_ptr<Graph> CreateGraph() {
auto* relu_out = global_block->Var("RELU_OUT");
relu_out->SetType(proto::VarType::LOD_TENSOR);
relu_out->SetLoDLevel(0);
relu_out->SetDataType(proto::VarType::FP32);
relu_out->SetShape({1000, 100});
relu_op->SetOutput("Out", {relu_out->Name()});
program.Flush();
return std::make_unique<Graph>(program);
}
} // namespace
TEST(CinnCompilerTest, FlagController) {
// init
auto* cinn_compiler = CinnCompiler::GetInstance();
auto cinn_pass = ir::PassRegistry::Instance().Get("build_cinn_pass");
// apply build_cinn_pass & FLAGS_allow_cinn_ops="add"
{
FLAGS_allow_cinn_ops = "add";
auto graph = CreateGraph();
cinn_compiler->Clear();
cinn_pass->Apply(graph.get());
auto compilation_keys = GetCompilationKeys(*graph);
ASSERT_EQ(compilation_keys.size(), 0);
}
// apply build_cinn_pass & FLAGS_allow_cinn_ops="mul;relu"
{
FLAGS_allow_cinn_ops = "mul;relu";
auto graph = CreateGraph();
cinn_compiler->Clear();
cinn_pass->Apply(graph.get());
auto compilation_keys = GetCompilationKeys(*graph);
ASSERT_EQ(compilation_keys.size(), 2);
}
// apply build_cinn_pass & FLAGS_allow_cinn_ops="" &
// FLAGS_deny_cinn_ops="relu"
{
FLAGS_allow_cinn_ops = "";
FLAGS_deny_cinn_ops = "elementwise_add;relu";
auto graph = CreateGraph();
cinn_compiler->Clear();
cinn_pass->Apply(graph.get());
auto compilation_keys = GetCompilationKeys(*graph);
ASSERT_EQ(compilation_keys.size(), 1);
const auto& compiling_graph = cinn_compiler->FindGraph(compilation_keys[0]);
auto op_types = ExtractOpTypes(compiling_graph);
ASSERT_EQ(op_types.size(), 2);
ASSERT_EQ(op_types.count("feed"), 1);
ASSERT_EQ(op_types.count("mul"), 1);
}
// recover flags
FLAGS_allow_cinn_ops = "";
FLAGS_deny_cinn_ops = "";
}
TEST(CinnCompilerTest, Compile) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
auto cinn_pass = ir::PassRegistry::Instance().Get("build_cinn_pass");
......@@ -113,32 +238,31 @@ TEST(CinnCompilerTest, Compile) {
cinn_pass->Apply(graph.get());
viz_graph("processed_graph.dot", graph.get());
// get the compilation_key
std::vector<std::string> compilation_keys;
for (auto& node : graph->Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) {
compilation_keys.emplace_back(
BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
}
}
auto compilation_keys = GetCompilationKeys(*graph);
ASSERT_EQ(compilation_keys.size(), 1);
const auto& compilation_key = compilation_keys[0];
auto* cinn_compiler = CinnCompiler::GetInstance();
VLOG(4) << "The graph to be compiled:\n"
<< cinn_compiler->VizGraph(compilation_key);
const auto& compiling_graph = cinn_compiler->FindGraph(compilation_key);
// viz_graph("compiling_graph.dot", const_cast<Graph*>(&compiling_graph));
viz_graph("compiling_graph.dot", const_cast<Graph*>(&compiling_graph));
EXPECT_THROW(cinn_compiler->FindGraph("no_existed"),
paddle::platform::EnforceNotMet);
LoDTensor tensor1, tensor2, tensor3;
tensor1.Resize({1000, 784});
tensor2.Resize({784, 100});
tensor3.Resize({100});
tensor1.mutable_data<float>(platform::CPUPlace());
tensor2.mutable_data<float>(platform::CPUPlace());
tensor3.mutable_data<float>(platform::CPUPlace());
std::map<std::string, const LoDTensor*> input_tensors = {
{"X", &tensor1}, {"Y", &tensor2}, {"Z", &tensor3}};
auto inputs_info = GetInputsInfo(compilation_key, *graph);
std::unordered_map<std::string, LoDTensor> create_inputs;
for (const auto& pair : inputs_info) {
auto& tensor = create_inputs[pair.first];
tensor.Resize(make_ddim(pair.second));
tensor.mutable_data<float>(platform::CPUPlace());
}
std::map<std::string, const LoDTensor*> input_tensors;
std::for_each(create_inputs.begin(), create_inputs.end(),
[&input_tensors](const auto& val) {
input_tensors.emplace(val.first, &val.second);
});
auto compile_fn = [&](const Target& target) {
const auto& compiled_obj =
......
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "paddle/fluid/operators/cinn_launch_op.h"
#include "cinn/frontend/var_type_utils.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
......
......@@ -98,7 +98,8 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
const auto& compilation_key =
ctx.template Attr<std::string>(kCompilationKey);
VLOG(4) << "CinnLaunchOp attribute(" << kCompilationKey << ") "
<< "value:" << compilation_key;
<< "value:\n"
<< CinnCompiler::GetInstance()->ReadableKey(compilation_key);
const auto& graph = CinnCompiler::GetInstance()->FindGraph(compilation_key);
auto input_variable_names = ctx.InputNames(kX);
......
......@@ -710,6 +710,7 @@ PADDLE_DEFINE_EXPORTED_bool(allreduce_record_one_event, false,
"events. Currently, only fuse allreduce supports "
"this. Otherwise, the precision may be wrong.");
#ifdef PADDLE_WITH_CINN
/*
* CINN related FLAG
* Name: FLAGS_use_cinn
......@@ -717,9 +718,31 @@ PADDLE_DEFINE_EXPORTED_bool(allreduce_record_one_event, false,
* Value Range: bool, default=false
* Example: FLAGS_use_cinn=true would run PaddlePaddle using CINN
*/
#ifdef PADDLE_WITH_CINN
PADDLE_DEFINE_EXPORTED_bool(
use_cinn, false, "It controls whether to run PaddlePaddle using CINN");
/*
* CINN related FLAG
* Name: FLAGS_allow_cinn_ops
* Since Version: 2.3
* Value Range: string, default=""
* Example: FLAGS_allow_cinn_ops="mul;relu" would only cover `mul` and `relu`
* when using CINN
*/
PADDLE_DEFINE_EXPORTED_string(allow_cinn_ops, "",
"It controls the cinn op subset to be used, "
"which has the highest priority.");
/*
* CINN related FLAG
* Name: FLAGS_deny_cinn_ops
* Since Version: 2.3
* Value Range: string, default=""
* Example: FLAGS_deny_cinn_ops="mul;relu" would block `mul` and `relu` two ops
* when using CINN
*/
PADDLE_DEFINE_EXPORTED_string(deny_cinn_ops, "",
"It controls the cinn op subset to be not used.");
#endif
DEFINE_int32(record_pool_max_size, 2000000,
......
......@@ -40,9 +40,9 @@ def set_cinn_flag(val):
def reader(limit):
for i in range(limit):
yield np.ones([1, 28]).astype('float32') * (i * 3.14 / (i + 1)), \
np.array([i + 1]).astype('int64')
for _ in range(limit):
yield np.random.random([1, 28]).astype('float32'), \
np.random.randint(0, 2, size=[1]).astype('int64')
def rand_data(img, label, loop_num=10):
......@@ -62,7 +62,7 @@ def build_program(main_program, startup_program):
shape=[1, 28],
dtype="float32",
attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(
np.ones([1, 28]).astype(np.float32))))
np.random.rand(1, 28).astype(np.float32))))
label = paddle.static.data(name="label", shape=[1], dtype='int64')
hidden = paddle.add(img, param)
......@@ -75,7 +75,12 @@ def build_program(main_program, startup_program):
return img, label, avg_loss
def do_test(dot_save_dir):
def train(dot_save_dir, prefix, seed=1234):
np.random.seed(seed)
paddle.seed(seed)
if paddle.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
img, label, loss = build_program(main_program, startup_program)
......@@ -86,32 +91,35 @@ def do_test(dot_save_dir):
exe.run(startup_program)
build_strategy = paddle.static.BuildStrategy()
build_strategy.debug_graphviz_path = os.path.join(dot_save_dir, "viz")
build_strategy.debug_graphviz_path = os.path.join(dot_save_dir, prefix)
compiled_program = paddle.static.CompiledProgram(
main_program, build_strategy).with_data_parallel(loss_name=loss.name)
iters = 1
iters = 100
feed = rand_data(img.name, label.name, iters)
loss_values = []
for step in range(iters):
loss_v = exe.run(compiled_program,
feed=feed[step],
fetch_list=[loss],
return_merged=False)
logger.info("loss value = {}".format(loss_v))
loss_values.append(loss_v[0][0][0])
return loss_values
@unittest.skipIf(not set_cinn_flag(True), "Paddle is not compiled with CINN.")
class TestParallelExecutorRunCinn(unittest.TestCase):
def setUp(self):
set_cinn_flag(True)
self.tmpdir = tempfile.mkdtemp(prefix="dots_")
def tearDown(self):
set_cinn_flag(False)
shutil.rmtree(self.tmpdir)
def test_run_with_cinn(self):
do_test(self.tmpdir)
cinn_losses = train(self.tmpdir, "paddle")
set_cinn_flag(False)
pd_losses = train(self.tmpdir, "cinn")
self.assertTrue(np.allclose(cinn_losses, pd_losses, atol=1e-5))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册