未验证 提交 9aecf286 编写于 作者: W Wilber 提交者: GitHub

convert_fp16 support multi block (#45050)

* convert_fp16 support multi block

* update

* update
上级 b0e7681f
...@@ -38,6 +38,7 @@ build_doc/ ...@@ -38,6 +38,7 @@ build_doc/
CMakeSettings.json CMakeSettings.json
Makefile Makefile
.test_env/ .test_env/
.cache/
third_party/ third_party/
*~ *~
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
...@@ -29,6 +30,7 @@ ...@@ -29,6 +30,7 @@
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h" #include "paddle/phi/common/layout.h"
...@@ -63,6 +65,7 @@ inline void StrToBinary(const std::string& path, const std::string& str) { ...@@ -63,6 +65,7 @@ inline void StrToBinary(const std::string& path, const std::string& str) {
file.write(str.c_str(), str.size()); file.write(str.c_str(), str.size());
file.close(); file.close();
} }
inline bool NodeVarHasDtype(framework::ir::Node* node) { inline bool NodeVarHasDtype(framework::ir::Node* node) {
if (node->IsCtrlVar()) return false; if (node->IsCtrlVar()) return false;
...@@ -80,12 +83,63 @@ inline bool NodeVarHasDtype(framework::ir::Node* node) { ...@@ -80,12 +83,63 @@ inline bool NodeVarHasDtype(framework::ir::Node* node) {
return false; return false;
} }
void SaveMixedModel(framework::ir::Graph* graph,
// Return Node* which first appers in block.
framework::ir::Node* GetRealNode(
const std::vector<framework::ir::Graph*>& graphes,
int block_idx,
framework::ir::Node* node,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map) {
if (vars_in_multi_block_map->count(node->Name())) {
int var_origin_block_id = vars_in_multi_block_map->at(node->Name()).second;
if (block_idx != var_origin_block_id) {
auto graph = graphes[var_origin_block_id];
for (auto nd : graph->Nodes()) {
if (nd->Name() == node->Name()) {
return nd;
}
}
}
}
return node;
}
inline bool VarIsMultiOpsOut(
const std::vector<framework::ir::Graph*>& graphes,
int block_idx,
framework::ir::Node* op_node,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map,
const std::vector<std::set<std::string>>& vars_appear_multi_in_one_block) {
CHECK_EQ(op_node->IsOp(), true);
for (auto* out : op_node->outputs) {
if (out->IsCtrlVar()) continue;
auto* real_node =
GetRealNode(graphes, block_idx, out, vars_in_multi_block_map);
if (!real_node->Var()->Persistable() &&
vars_appear_multi_in_one_block[block_idx].count(out->Name())) {
VLOG(2) << out->Name()
<< " is multi op's out, so we skip convert to fp16";
return true;
}
}
return false;
}
void SaveMixedModel(
framework::ir::Graph* graph,
framework::Scope* scope, framework::Scope* scope,
framework::ProgramDesc* mixed_program_desc, framework::ProgramDesc* mixed_program_desc,
const std::string& mixed_model_file, const std::string& mixed_model_file,
const std::string& mixed_params_file, const std::string& mixed_params_file,
phi::DataType mixed_precision) { phi::DataType mixed_precision,
const std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>&
vars_in_multi_block_map) {
paddle::CPUPlace place; paddle::CPUPlace place;
auto parameters = scope->LocalVarNames(); auto parameters = scope->LocalVarNames();
std::sort(parameters.begin(), parameters.end()); std::sort(parameters.begin(), parameters.end());
...@@ -169,7 +223,8 @@ bool GpuKernelSupportPrecision( ...@@ -169,7 +223,8 @@ bool GpuKernelSupportPrecision(
auto it = all_kernels.find(op_type); auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) { if (it != all_kernels.end()) {
for (auto& kern_pair : it->second) { for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_)) { if (platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.data_type_ == framework::proto::VarType::FP16) {
res = true; res = true;
} }
} }
...@@ -205,10 +260,18 @@ bool OutShouldNotConvert(ir::Node* var_node) { ...@@ -205,10 +260,18 @@ bool OutShouldNotConvert(ir::Node* var_node) {
return false; return false;
} }
void ProcessOutputNode(ir::Node* var_node, void ProcessOutputNode(
framework::proto::VarType::Type to_type) { const std::vector<framework::ir::Graph*>& graphes,
if (!NodeVarHasDtype(var_node)) return; int block_idx,
auto* out_var = var_node->Var(); ir::Node* var_node,
framework::proto::VarType::Type to_type,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map) {
auto* real_node =
GetRealNode(graphes, block_idx, var_node, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) return;
auto* out_var = real_node->Var();
if (out_var->GetDataType() == framework::proto::VarType::FP32) { if (out_var->GetDataType() == framework::proto::VarType::FP32) {
if (OutShouldNotConvert(var_node)) return; if (OutShouldNotConvert(var_node)) return;
out_var->SetDataType(to_type); out_var->SetDataType(to_type);
...@@ -241,6 +304,26 @@ bool WeightsShouldNotConvert(ir::Node* var_node) { ...@@ -241,6 +304,26 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true; return true;
} }
} else if (op_desc->Type() == "fused_multi_transformer") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("LnBias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnScale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnBias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
} }
} }
...@@ -255,21 +338,28 @@ inline bool IsFloatVarType(framework::proto::VarType::Type type) { ...@@ -255,21 +338,28 @@ inline bool IsFloatVarType(framework::proto::VarType::Type type) {
} }
void ProcessInputNode( void ProcessInputNode(
bool support_precision, bool support_precision,
framework::ir::Graph* graph, std::vector<framework::ir::Graph*> graphes,
ir::Node* in_node, ir::Node* in_node,
ir::Node* op_node, ir::Node* op_node,
int* suffix, int* suffix,
framework::BlockDesc* block_desc, framework::BlockDesc* block_desc,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* cast_map, std::unordered_map<framework::ir::Node*, framework::ir::Node*>* cast_map,
framework::proto::VarType::Type to_type, framework::proto::VarType::Type to_type,
bool is_main_block, int block_idx,
std::unordered_map<std::string, framework::proto::VarType::Type>* std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map) { vars_in_multi_block_map) {
if (!NodeVarHasDtype(in_node)) return; auto* real_node =
auto* in_var = in_node->Var(); GetRealNode(graphes, block_idx, in_node, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) return;
auto graph = graphes[block_idx];
bool is_main_block = block_idx == 0;
auto* in_var = real_node->Var();
auto in_var_type = in_var->GetDataType(); auto in_var_type = in_var->GetDataType();
if (!is_main_block && vars_in_multi_block_map->count(in_var->Name())) { bool is_in_multi_block = vars_in_multi_block_map->count(in_var->Name());
in_var_type = vars_in_multi_block_map->at(in_var->Name());
if (!is_main_block && is_in_multi_block) {
in_var_type = vars_in_multi_block_map->at(in_var->Name()).first;
} }
if (support_precision) { if (support_precision) {
if (in_var->Persistable() && if (in_var->Persistable() &&
...@@ -300,8 +390,7 @@ void ProcessInputNode( ...@@ -300,8 +390,7 @@ void ProcessInputNode(
cast_map); cast_map);
} }
} }
VLOG(3) << " in_node name " << in_var->Name() << " data_type " VLOG(3) << " in_node name " << in_var->Name() << " data_type " << in_var_type;
<< in_var->GetDataType();
} }
void ConvertAllFp64ToFp32(framework::ir::Graph* graph) { void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
...@@ -405,45 +494,87 @@ void FixCastAttr(framework::ir::Graph* graph) { ...@@ -405,45 +494,87 @@ void FixCastAttr(framework::ir::Graph* graph) {
void FindVarsInMultiBlock( void FindVarsInMultiBlock(
framework::ProgramDesc* program_desc, framework::ProgramDesc* program_desc,
std::unordered_map<std::string, framework::proto::VarType::Type>* std::unordered_map<std::string,
vars_in_multi_block_map) { std::pair<framework::proto::VarType::Type, int>>*
std::set<std::string> vars_in_multi_block; vars_in_multi_block_map,
std::set<std::string> main_block_var_names_set; std::vector<std::set<std::string>>* vars_appear_multi_in_one_block) {
for (auto op : program_desc->Block(0).AllOps()) { std::vector<std::set<std::string>> block_var_names_set(program_desc->Size());
auto in_names = op->InputArgumentNames(); for (size_t i = 0; i < program_desc->Size(); ++i) {
main_block_var_names_set.insert(in_names.begin(), in_names.end());
}
for (size_t i = 1; i < program_desc->Size(); ++i) {
std::set<std::string> block_var_names_set;
for (auto op : program_desc->Block(i).AllOps()) { for (auto op : program_desc->Block(i).AllOps()) {
auto in_names = op->InputArgumentNames(); auto in_names = op->InputArgumentNames();
block_var_names_set.insert(in_names.begin(), in_names.end()); block_var_names_set[i].insert(in_names.begin(), in_names.end());
auto out_names = op->OutputArgumentNames();
if (op->HasAttr("sub_block") == false) {
for (auto& n : out_names) {
if (block_var_names_set[i].count(n)) {
(*vars_appear_multi_in_one_block)[i].insert(n);
}
}
}
block_var_names_set[i].insert(out_names.begin(), out_names.end());
}
} }
for (size_t i = 0; i < program_desc->Size() - 1; ++i) {
for (size_t j = i + 1; j < program_desc->Size(); ++j) {
std::set<std::string> vars_in_multi_block;
std::set_intersection( std::set_intersection(
main_block_var_names_set.begin(), block_var_names_set[i].begin(),
main_block_var_names_set.end(), block_var_names_set[i].end(),
block_var_names_set.begin(), block_var_names_set[j].begin(),
block_var_names_set.end(), block_var_names_set[j].end(),
std::inserter(vars_in_multi_block, vars_in_multi_block.begin())); std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
}
for (auto name : vars_in_multi_block) { for (auto name : vars_in_multi_block) {
vars_in_multi_block_map->emplace(name, framework::proto::VarType::FP32); vars_in_multi_block_map->emplace(
name, std::make_pair(framework::proto::VarType::FP32, i));
}
}
}
}
bool OpInOutHasTensorArray(
std::vector<framework::ir::Graph*> graphes,
int block_idx,
framework::ir::Node* op_node,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map) {
CHECK_EQ(op_node->IsOp(), true);
for (auto in : op_node->inputs) {
auto* real_node =
GetRealNode(graphes, block_idx, in, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) continue;
if (real_node->Var()->GetType() ==
framework::proto::VarType::LOD_TENSOR_ARRAY)
return true;
}
for (auto out : op_node->outputs) {
auto* real_node =
GetRealNode(graphes, block_idx, out, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) continue;
if (real_node->Var()->GetType() ==
framework::proto::VarType::LOD_TENSOR_ARRAY)
return true;
} }
return false;
} }
void ConvertTensorDtype( void ConvertTensorDtype(
framework::ProgramDesc* program_desc, framework::ProgramDesc* program_desc,
framework::ir::Graph* graph, std::vector<framework::ir::Graph*> graphes,
const std::unordered_set<std::string>& blacklist, const std::unordered_set<std::string>& blacklist,
bool keep_io_types, bool keep_io_types,
phi::Backend backend, phi::Backend backend,
phi::DataType tensor_dtype, phi::DataType tensor_dtype,
bool is_main_block, int block_idx,
std::unordered_map<std::string, framework::proto::VarType::Type>* std::unordered_map<std::string,
vars_in_multi_block_map) { std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map,
const std::vector<std::set<std::string>>& vars_appear_multi_in_one_block) {
auto graph = graphes[block_idx];
framework::proto::VarType::Type to_type; framework::proto::VarType::Type to_type;
if (tensor_dtype == phi::DataType::FLOAT16) { if (tensor_dtype == phi::DataType::FLOAT16) {
to_type = framework::proto::VarType::FP16; to_type = framework::proto::VarType::FP16;
...@@ -452,8 +583,7 @@ void ConvertTensorDtype( ...@@ -452,8 +583,7 @@ void ConvertTensorDtype(
} else { } else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument( PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported dtype %d, we now only " "mixed_precision currently not supported dtype %d, we now only "
"support " "support fp16 and bf16.",
"fp16 and bf16.",
static_cast<int>(tensor_dtype))); static_cast<int>(tensor_dtype)));
} }
...@@ -490,15 +620,19 @@ void ConvertTensorDtype( ...@@ -490,15 +620,19 @@ void ConvertTensorDtype(
// same name. // same name.
std::unordered_map<std::string, framework::ir::Node*> in_name_to_node; std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
for (auto* in : op_node->inputs) { for (auto* in : op_node->inputs) {
if (NodeVarHasDtype(in)) { auto* real_node =
GetRealNode(graphes, block_idx, in, vars_in_multi_block_map);
if (NodeVarHasDtype(real_node)) {
in_name_to_node[in->Name()] = in; in_name_to_node[in->Name()] = in;
} }
} }
for (auto out : op_node->outputs) { for (auto out : op_node->outputs) {
if (NodeVarHasDtype(out)) { auto* real_node =
GetRealNode(graphes, block_idx, out, vars_in_multi_block_map);
if (NodeVarHasDtype(real_node)) {
if (in_name_to_node.count(out->Name())) if (in_name_to_node.count(out->Name()))
out->Var()->SetDataType( real_node->Var()->SetDataType(
in_name_to_node[out->Name()]->Var()->GetDataType()); in_name_to_node[out->Name()]->Var()->GetDataType());
} }
} }
...@@ -506,17 +640,39 @@ void ConvertTensorDtype( ...@@ -506,17 +640,39 @@ void ConvertTensorDtype(
continue; continue;
} }
// A strange case found in multi block.
else if (op_type == "assign" && // NOLINT
op_node->inputs[0]->Name() == op_node->outputs[0]->Name()) {
VLOG(2) << " in out are same, continue";
continue;
}
// Handle tensor array.
else if (OpInOutHasTensorArray( // NOLINT
graphes,
block_idx,
op_node,
vars_in_multi_block_map)) {
VLOG(2) << " in or out has tensor array, continue";
continue;
}
// 2. if op support fp16/bf16 and not in blacklist. // 2. if op support fp16/bf16 and not in blacklist.
// - cast weight to fp16/bf16. // - cast weight to fp16/bf16.
// - add cast op if the input dtype is not fp16/bf16. // - add cast op if the input dtype is not fp16/bf16.
// - set output dtype. // - set output dtype.
else if (blacklist.count(op_type) == 0) { // NOLINT //
// If a var(op's out var) appears multiple times in a block, we should not
// convert to fp16.
else if (blacklist.count(op_type) == 0 && // NOLINT
!VarIsMultiOpsOut(graphes,
block_idx,
op_node,
vars_in_multi_block_map,
vars_appear_multi_in_one_block)) {
bool support_precision = bool support_precision =
OpSupportPrecision(op_type, backend, tensor_dtype, blacklist); OpSupportPrecision(op_type, backend, tensor_dtype, blacklist);
VLOG(2) << "op_type " << op_type << ", phi_op_type " VLOG(2) << " support low precision " << support_precision;
<< phi::TransToPhiKernelName(op_type) << " support low precision "
<< support_precision << ", "
<< reinterpret_cast<void*>(op_node->Op()->Block());
if (support_precision) { if (support_precision) {
HandleSpecialOps(op_node->Op()); HandleSpecialOps(op_node->Op());
...@@ -525,32 +681,33 @@ void ConvertTensorDtype( ...@@ -525,32 +681,33 @@ void ConvertTensorDtype(
// Process inputs. // Process inputs.
for (auto* in_node : inputs) { for (auto* in_node : inputs) {
ProcessInputNode(true, ProcessInputNode(true,
graph, graphes,
in_node, in_node,
op_node, op_node,
&suffix, &suffix,
block_desc, block_desc,
&cast_map, &cast_map,
to_type, to_type,
is_main_block, block_idx,
vars_in_multi_block_map); vars_in_multi_block_map);
} }
// Process outputs. // Process outputs.
for (auto* out_node : op_node->outputs) { for (auto* out_node : op_node->outputs) {
ProcessOutputNode(out_node, to_type); ProcessOutputNode(
graphes, block_idx, out_node, to_type, vars_in_multi_block_map);
} }
} else { } else {
auto inputs = op_node->inputs; auto inputs = op_node->inputs;
for (auto* in_node : inputs) { for (auto* in_node : inputs) {
ProcessInputNode(false, ProcessInputNode(false,
graph, graphes,
in_node, in_node,
op_node, op_node,
&suffix, &suffix,
block_desc, block_desc,
&cast_map, &cast_map,
framework::proto::VarType::FP32, framework::proto::VarType::FP32,
is_main_block, block_idx,
vars_in_multi_block_map); vars_in_multi_block_map);
} }
} }
...@@ -606,16 +763,21 @@ void ConvertTensorDtype( ...@@ -606,16 +763,21 @@ void ConvertTensorDtype(
} }
} }
if (is_main_block) {
for (auto node : graph->Nodes()) { for (auto node : graph->Nodes()) {
if (vars_in_multi_block_map->count(node->Name())) { auto* real_node =
vars_in_multi_block_map->at(node->Name()) = node->Var()->GetDataType(); GetRealNode(graphes, block_idx, node, vars_in_multi_block_map);
} if (!NodeVarHasDtype(real_node)) continue;
if (vars_in_multi_block_map->count(real_node->Name()) &&
vars_in_multi_block_map->at(real_node->Name()).second == block_idx) {
vars_in_multi_block_map->at(real_node->Name()).first =
real_node->Var()->GetDataType();
} }
} }
if (num_low_precision) if (num_low_precision)
LOG(INFO) << "--- detected " << num_low_precision << " low precision ops"; LOG(INFO) << "--- detected " << num_low_precision
<< " low precision ops in " << block_idx << " subgraph";
} }
} // namespace } // namespace
...@@ -701,26 +863,32 @@ void ConvertToMixedPrecision(const std::string& model_file, ...@@ -701,26 +863,32 @@ void ConvertToMixedPrecision(const std::string& model_file,
auto main_graph = std::unique_ptr<framework::ir::Graph>( auto main_graph = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(*program_desc)); new framework::ir::Graph(*program_desc));
std::unordered_map<std::string, framework::proto::VarType::Type> std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>
vars_in_multi_block_map; vars_in_multi_block_map;
FindVarsInMultiBlock(program_desc.get(), &vars_in_multi_block_map); std::vector<std::set<std::string>> vars_appear_multi_in_one_block(
program_desc->Size());
FindVarsInMultiBlock(program_desc.get(),
&vars_in_multi_block_map,
&vars_appear_multi_in_one_block);
std::vector<framework::ir::Graph*> graphes;
for (size_t i = 0; i < main_graph->SubGraphsSize(); ++i) { for (size_t i = 0; i < main_graph->SubGraphsSize(); ++i) {
auto graph = main_graph->GetSubGraph(i); auto graph = main_graph->GetSubGraph(i);
graphes.push_back(graph);
VLOG(2) << " -------- handle subgraph " << i << ", has " VLOG(2) << " -------- handle subgraph " << i << ", has "
<< graph->Nodes().size() << " nodes"; << graph->Nodes().size() << " nodes --------";
program_desc->Block(i).LocalVarNames();
ConvertAllFp64ToFp32(graph); ConvertAllFp64ToFp32(graph);
ConvertTensorDtype(program_desc.get(), ConvertTensorDtype(program_desc.get(),
graph, graphes,
black_list, black_list,
keep_io_types, keep_io_types,
backend, backend,
mixed_precision, mixed_precision,
i == 0, i,
&vars_in_multi_block_map); &vars_in_multi_block_map,
vars_appear_multi_in_one_block);
FixCastAttr(graph); FixCastAttr(graph);
} }
...@@ -732,7 +900,8 @@ void ConvertToMixedPrecision(const std::string& model_file, ...@@ -732,7 +900,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
&mixed_program_desc, &mixed_program_desc,
mixed_model_file, mixed_model_file,
mixed_params_file, mixed_params_file,
mixed_precision); mixed_precision,
vars_in_multi_block_map);
} }
} // namespace analysis } // namespace analysis
......
...@@ -438,15 +438,14 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -438,15 +438,14 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
cudnn_output_desc, cudnn_output_desc,
algo, algo,
&workspace_size_in_bytes)); &workspace_size_in_bytes));
PADDLE_ENFORCE_LE( // PADDLE_ENFORCE_LE(
workspace_size_in_bytes, // workspace_size_in_bytes,
workspace_size_limit, // workspace_size_limit,
platform::errors::InvalidArgument( // platform::errors::InvalidArgument(
"The actual workspace size to be allocated for cuDNN is expected " // "The actual workspace size to be allocated for cuDNN is expected
"to be less than the limit. But received: the actual workspace " // " "to be less than the limit. But received: the actual workspace
"size = %d, limit = %d.", // " "size = %d, limit = %d.", workspace_size_in_bytes,
workspace_size_in_bytes, // workspace_size_limit));
workspace_size_limit));
if ((activation == "identity") && (!residual)) { if ((activation == "identity") && (!residual)) {
// Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册