提交 f010edd0 编写于 作者: C chonwhite

attention works

上级 b4e07620
......@@ -67,8 +67,8 @@ class Optimizer {
"lite_conv_elementwise_fuse_pass", // conv-bn-elemwise
// TODO(Superjomn) Refine the fusion related design to select fusion
// kernels for devices automatically.
"lite_conv_activation_fuse_pass", //
"lite_fc_fuse_pass", //
"lite_conv_activation_fuse_pass", //
// "lite_fc_fuse_pass", //
"lite_shuffle_channel_fuse_pass", //
"lite_transpose_softmax_transpose_fuse_pass", //
"lite_interpolate_fuse_pass", //
......
......@@ -80,32 +80,37 @@ class FillConstantBatchLikeCompute
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<ARMContext>();
if (param.input->lod().size() && param.input_dim_idx == 0) {
auto odims = param.out->dims();
odims[param.output_dim_idx] = param.input->lod().back().size() - 1;
param.out->Resize(odims);
auto data = param.out->template mutable_data<T>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
auto data = param.out->template mutable_data<float>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT32)) {
auto data = param.out->template mutable_data<int32_t>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT8)) {
auto data = param.out->template mutable_data<int8_t>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else {
LOG(FATAL) << "not supported dtype " << param.dtype;
}
// if (param.input->lod().size() && param.input_dim_idx == 0) {
// auto odims = param.out->dims();
// odims[param.output_dim_idx] = param.input->lod().back().size() - 1;
// param.out->Resize(odims);
// }
// if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
// auto data = param.out->template mutable_data<float>();
// for (int i = 0; i < param.out->numel(); i++) {
// data[i] = param.value;
// }
// } else if (param.dtype ==
// static_cast<int32_t>(lite::core::FluidType::INT32)) {
// auto data = param.out->template mutable_data<int32_t>();
// for (int i = 0; i < param.out->numel(); i++) {
// data[i] = param.value;
// }
// } else if (param.dtype ==
// static_cast<int32_t>(lite::core::FluidType::INT8)) {
// auto data = param.out->template mutable_data<int8_t>();
// for (int i = 0; i < param.out->numel(); i++) {
// data[i] = param.value;
// }
// } else {
// LOG(FATAL) << "not supported dtype " << param.dtype;
// }
}
virtual ~FillConstantBatchLikeCompute() = default;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册