未验证 提交 80d35725 编写于 作者: Z Zhaolong Xing 提交者: GitHub

align yolov3 cuda int8 (#2183)

test=develop
上级 56151776
...@@ -537,16 +537,16 @@ bool CudnnConv2DInt8<Ptype_out>::run(const operators::ConvParam& param) { ...@@ -537,16 +537,16 @@ bool CudnnConv2DInt8<Ptype_out>::run(const operators::ConvParam& param) {
static_cast<const void*>(scale), static_cast<const void*>(scale),
this->stream_); this->stream_);
} else { } else {
bias_int8_nhwc<int8_t>(num, bias_int8_nhwc<float>(num,
static_cast<const void*>(temp_out), static_cast<const void*>(temp_out),
static_cast<const void*>(b_data), static_cast<const void*>(b_data),
static_cast<void*>(temp_out), static_cast<void*>(temp_out),
n, n,
c, c,
h, h,
w, w,
static_cast<const void*>(scale), static_cast<const void*>(scale),
this->stream_); this->stream_);
} }
return true; return true;
} }
......
...@@ -30,17 +30,17 @@ void TypeLayoutTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -30,17 +30,17 @@ void TypeLayoutTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Start from inputs of the graph, those should have place set. // Start from inputs of the graph, those should have place set.
VLOG(4) << "\n" << Visualize(graph.get()); VLOG(4) << "\n" << Visualize(graph.get());
std::list<Node*> nodes; std::list<Node*> nodes;
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->StmtTopologicalOrder()) {
nodes.push_back(&node); nodes.push_back(node);
} }
LOG(INFO) << "nodes.size():" << nodes.size(); VLOG(4) << "nodes.size():" << nodes.size();
for (auto& node : nodes) { for (auto& node : nodes) {
LOG(INFO) << "!node->IsStmt():" << !node->IsStmt(); VLOG(4) << "!node->IsStmt():" << !node->IsStmt();
if (!node->IsStmt()) continue; if (!node->IsStmt()) continue;
auto inlinks = node->inlinks; auto inlinks = node->inlinks;
LOG(INFO) << "node->AsStmt().desc:" << node->AsStmt().desc VLOG(4) << "node->AsStmt().desc:" << node->AsStmt().desc
<< " inlinks.size():" << inlinks.size(); << " inlinks.size():" << inlinks.size();
for (auto* in : inlinks) { for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in); ComplementInputs(graph.get(), node, in);
} }
...@@ -58,7 +58,7 @@ void TypeLayoutTransformPass::ComplementInputs(SSAGraph* graph, ...@@ -58,7 +58,7 @@ void TypeLayoutTransformPass::ComplementInputs(SSAGraph* graph,
CHECK(inst_node->IsStmt()); CHECK(inst_node->IsStmt());
auto& inst = inst_node->AsStmt(); auto& inst = inst_node->AsStmt();
LOG(INFO) << "found Target tensor: " << in->AsArg().name; VLOG(4) << "found Target tensor: " << in->AsArg().name;
CHECK(in->IsRoleSet()); CHECK(in->IsRoleSet());
CHECK(in->IsArg()); CHECK(in->IsArg());
auto in_arg_name = in->AsArg().name; auto in_arg_name = in->AsArg().name;
...@@ -66,15 +66,15 @@ void TypeLayoutTransformPass::ComplementInputs(SSAGraph* graph, ...@@ -66,15 +66,15 @@ void TypeLayoutTransformPass::ComplementInputs(SSAGraph* graph,
CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp)); CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp); auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
CHECK(in->AsArg().type); CHECK(in->AsArg().type);
LOG(INFO) << "\n tmp:" << tmp << "\n in->AsArg().name:" << in->AsArg().name VLOG(4) << "\n tmp:" << tmp << "\n in->AsArg().name:" << in->AsArg().name
<< "\n *in->AsArg().type:" << *in->AsArg().type << "\n *in->AsArg().type:" << *in->AsArg().type
<< "\n *decl_arg_type:" << *decl_arg_type << "\n *decl_arg_type:" << *decl_arg_type
<< "\n inst.op()->DebugString():" << inst.op()->DebugString(); << "\n inst.op()->DebugString():" << inst.op()->DebugString();
if (!DataLayoutCompatible(*in->AsArg().type, *decl_arg_type)) { if (!DataLayoutCompatible(*in->AsArg().type, *decl_arg_type)) {
LOG(INFO) << "found Layout unmatched tensor: " << in->AsArg().name VLOG(4) << "found Layout unmatched tensor: " << in->AsArg().name
<< " for kernel " << inst.op()->DebugString() << " " << " for kernel " << inst.op()->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type; << *in->AsArg().type << " -> " << *decl_arg_type;
AddLayoutInst(*in->AsArg().type, AddLayoutInst(*in->AsArg().type,
*decl_arg_type, *decl_arg_type,
in, in,
...@@ -94,9 +94,9 @@ void TypeLayoutTransformPass::AddLayoutInst( ...@@ -94,9 +94,9 @@ void TypeLayoutTransformPass::AddLayoutInst(
CHECK(!valid_places.empty()) << "valid_place should be set"; CHECK(!valid_places.empty()) << "valid_place should be set";
CHECK(in->IsArg()); CHECK(in->IsArg());
auto node_id = [&] { return graph->nodes().size(); }; // auto node_id = [&] { return graph->nodes().size(); };
auto layout_output_name = auto layout_output_name =
string_format("%s/layout_trans/%d", in->AsArg().name.c_str(), node_id()); string_format("%s/layout_trans", in->AsArg().name.c_str());
auto* layout_output_arg = graph->NewArgumentNode(layout_output_name); auto* layout_output_arg = graph->NewArgumentNode(layout_output_name);
layout_output_arg->AsArg().type = layout_output_arg->AsArg().type =
LiteType::GetTensorTy(from.target(), from.precision(), to.layout()); LiteType::GetTensorTy(from.target(), from.precision(), to.layout());
...@@ -145,10 +145,10 @@ void TypeLayoutTransformPass::AddLayoutInst( ...@@ -145,10 +145,10 @@ void TypeLayoutTransformPass::AddLayoutInst(
CHECK(is_found) << "Can't find a layout kernel for layout op: " << from CHECK(is_found) << "Can't find a layout kernel for layout op: " << from
<< ":" << in->AsArg().name << "->" << to << ":" << ":" << in->AsArg().name << "->" << to << ":"
<< inst_node->AsStmt().op_info()->Type(); << inst_node->AsStmt().op_info()->Type();
LOG(INFO) << "========= final picked kernel [info]:" VLOG(4) << "========= final picked kernel [info]:"
<< layout_inst->AsStmt().picked_kernel().name() << layout_inst->AsStmt().picked_kernel().name()
<< " [summary]:" << layout_inst->AsStmt().picked_kernel().summary() << " [summary]:" << layout_inst->AsStmt().picked_kernel().summary()
<< "\n"; << "\n";
// Remove the old link // Remove the old link
RemoveDirectedLink(in, inst_node); RemoveDirectedLink(in, inst_node);
......
...@@ -28,8 +28,8 @@ namespace mir { ...@@ -28,8 +28,8 @@ namespace mir {
void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) { void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Start from inputs of the graph, those should have place set. // Start from inputs of the graph, those should have place set.
std::list<Node*> nodes; std::list<Node*> nodes;
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->StmtTopologicalOrder()) {
nodes.push_back(&node); nodes.push_back(node);
} }
for (auto& node : nodes) { for (auto& node : nodes) {
...@@ -86,9 +86,9 @@ void PrecisionCastPass::AddCastInst(const Type& from, ...@@ -86,9 +86,9 @@ void PrecisionCastPass::AddCastInst(const Type& from,
// var -> new_transform_op -> new_var -> inst // var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new Cast Statement Node. // So there will be a new Argument node and a new Cast Statement Node.
CHECK(in->IsArg()); CHECK(in->IsArg());
auto node_id = [&] { return graph->nodes().size(); }; // auto node_id = [&] { return graph->nodes().size(); };
auto cast_op_output_name = auto cast_op_output_name = in->AsArg().name + "/precision_trans";
in->AsArg().name + "/precision_trans/" + std::to_string(node_id()); // in->AsArg().name + "/precision_trans/" + std::to_string(node_id());
auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name); auto* cast_op_output_arg = graph->NewArgumentNode(cast_op_output_name);
cast_op_output_arg->AsArg().type = cast_op_output_arg->AsArg().type =
LiteType::GetTensorTy(from.target(), to.precision(), from.layout()); LiteType::GetTensorTy(from.target(), to.precision(), from.layout());
......
...@@ -29,8 +29,8 @@ namespace mir { ...@@ -29,8 +29,8 @@ namespace mir {
void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) { void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Start from inputs of the graph, those should have place set. // Start from inputs of the graph, those should have place set.
std::list<Node*> nodes; std::list<Node*> nodes;
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->StmtTopologicalOrder()) {
nodes.push_back(&node); nodes.push_back(node);
} }
CHECK(!valid_places_.empty()); CHECK(!valid_places_.empty());
...@@ -60,7 +60,6 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, ...@@ -60,7 +60,6 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph,
auto in_arg_name = in->AsArg().name; auto in_arg_name = in->AsArg().name;
std::string tmp; std::string tmp;
CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp)); CHECK(inst.op_info()->GetInputArgname(in_arg_name, &tmp));
LOG(INFO) << "tmp:" << tmp;
auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp); auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
CHECK(in->AsArg().type); CHECK(in->AsArg().type);
if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) { if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) {
...@@ -85,9 +84,10 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -85,9 +84,10 @@ void TypeTargetTransformPass::AddIoCopyInst(
// So there will be a new Argument node and a new IoCopy Statement Node. // So there will be a new Argument node and a new IoCopy Statement Node.
CHECK(in->IsArg()); CHECK(in->IsArg());
auto node_id = [&] { return graph->nodes().size(); }; // auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name = auto io_copy_output_name =
string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id()); string_format("%s/target_trans", in->AsArg().name.c_str());
// string_format("%s/target_trans/%d", in->AsArg().name.c_str(), node_id());
// TODO(MyPandaShaoxiang) should set same place with input? // TODO(MyPandaShaoxiang) should set same place with input?
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name); auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
// Set the place for io_copy_output_arg node, the target should be equal to // Set the place for io_copy_output_arg node, the target should be equal to
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <vector> #include <vector>
#include "lite/backends/cuda/math/utils.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/core/type_system.h" #include "lite/core/type_system.h"
#include "lite/kernels/cuda/calib_compute.h" #include "lite/kernels/cuda/calib_compute.h"
...@@ -22,19 +23,13 @@ namespace lite { ...@@ -22,19 +23,13 @@ namespace lite {
namespace kernels { namespace kernels {
namespace cuda { namespace cuda {
__device__ __forceinline__ int8_t float2int8(float x) {
x = fmaxf(x, INT8_MIN);
x = fminf(x, INT8_MAX);
return __float2int_rn(x);
}
__global__ void Fp32ToInt8Kernel(const int num, __global__ void Fp32ToInt8Kernel(const int num,
const float scale, const float scale,
const float* input, const float* input,
int8_t* output) { int8_t* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x; int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) { if (index < num) {
output[index] = float2int8(input[index] / scale); output[index] = lite::cuda::math::from_float<int8_t>(input[index] / scale);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册