未验证 提交 91dd0f0d 编写于 作者: W wenbin 提交者: GitHub

compile fix (#39272)

* slice

* shuffle pass enhancement
上级 2e6be886
...@@ -192,8 +192,7 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { ...@@ -192,8 +192,7 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
if (reshape2_shape[1] != reshape1_shape[2] * reshape1_shape[1]) return; if (reshape2_shape[1] != reshape1_shape[2] * reshape1_shape[1]) return;
// trans on channel dim // trans on channel dim
if (trans_axis[0] != 0 || trans_axis[3] != 3 || trans_axis[4] != 4) return; if (trans_axis[0] != 0 || trans_axis[3] != 3 || trans_axis[4] != 4) return;
if (group != 1 && i_c != 1) {
if (group != 1) {
if (trans_axis[1] != 2 && trans_axis[2] != 1) { if (trans_axis[1] != 2 && trans_axis[2] != 1) {
return; return;
} }
......
...@@ -128,10 +128,14 @@ int SpecialSlicePluginDynamic::enqueue( ...@@ -128,10 +128,14 @@ int SpecialSlicePluginDynamic::enqueue(
auto input_dims = input_desc[0].dims; // (sum(S), hidden, 1, 1) auto input_dims = input_desc[0].dims; // (sum(S), hidden, 1, 1)
auto out_dims = output_desc[0].dims; // (batch, hidden, 1, 1) auto out_dims = output_desc[0].dims; // (batch, hidden, 1, 1)
assert(input_desc[0].type == nvinfer1::DataType::kHALF); PADDLE_ENFORCE_EQ(
assert(hidden % 128 == 0); input_desc[0].type, nvinfer1::DataType::kHALF,
platform::errors::InvalidArgument("Type of input should be half."));
const int32_t hidden = input_dims.d[1]; const int32_t hidden = input_dims.d[1];
PADDLE_ENFORCE_EQ(hidden % 128, 0, platform::errors::InvalidArgument(
"hidden should be multiple of 128."));
constexpr int num_threads = 128; constexpr int num_threads = 128;
const dim3 blocks(out_dims.d[0], hidden / num_threads); const dim3 blocks(out_dims.d[0], hidden / num_threads);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册