提交 f010edd0 编写于 作者: C chonwhite

attention works

上级 b4e07620
...@@ -68,7 +68,7 @@ class Optimizer { ...@@ -68,7 +68,7 @@ class Optimizer {
// TODO(Superjomn) Refine the fusion related design to select fusion // TODO(Superjomn) Refine the fusion related design to select fusion
// kernels for devices automatically. // kernels for devices automatically.
"lite_conv_activation_fuse_pass", // "lite_conv_activation_fuse_pass", //
"lite_fc_fuse_pass", // // "lite_fc_fuse_pass", //
"lite_shuffle_channel_fuse_pass", // "lite_shuffle_channel_fuse_pass", //
"lite_transpose_softmax_transpose_fuse_pass", // "lite_transpose_softmax_transpose_fuse_pass", //
"lite_interpolate_fuse_pass", // "lite_interpolate_fuse_pass", //
......
...@@ -80,32 +80,37 @@ class FillConstantBatchLikeCompute ...@@ -80,32 +80,37 @@ class FillConstantBatchLikeCompute
auto& param = *param_.get_mutable<param_t>(); auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<ARMContext>(); auto& context = ctx_->As<ARMContext>();
if (param.input->lod().size() && param.input_dim_idx == 0) { auto data = param.out->template mutable_data<T>();
auto odims = param.out->dims();
odims[param.output_dim_idx] = param.input->lod().back().size() - 1;
param.out->Resize(odims);
}
if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
auto data = param.out->template mutable_data<float>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT32)) {
auto data = param.out->template mutable_data<int32_t>();
for (int i = 0; i < param.out->numel(); i++) { for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value; data[i] = param.value;
} }
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT8)) { // if (param.input->lod().size() && param.input_dim_idx == 0) {
auto data = param.out->template mutable_data<int8_t>(); // auto odims = param.out->dims();
for (int i = 0; i < param.out->numel(); i++) { // odims[param.output_dim_idx] = param.input->lod().back().size() - 1;
data[i] = param.value; // param.out->Resize(odims);
} // }
} else {
LOG(FATAL) << "not supported dtype " << param.dtype; // if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
} // auto data = param.out->template mutable_data<float>();
// for (int i = 0; i < param.out->numel(); i++) {
// data[i] = param.value;
// }
// } else if (param.dtype ==
// static_cast<int32_t>(lite::core::FluidType::INT32)) {
// auto data = param.out->template mutable_data<int32_t>();
// for (int i = 0; i < param.out->numel(); i++) {
// data[i] = param.value;
// }
// } else if (param.dtype ==
// static_cast<int32_t>(lite::core::FluidType::INT8)) {
// auto data = param.out->template mutable_data<int8_t>();
// for (int i = 0; i < param.out->numel(); i++) {
// data[i] = param.value;
// }
// } else {
// LOG(FATAL) << "not supported dtype " << param.dtype;
// }
} }
virtual ~FillConstantBatchLikeCompute() = default; virtual ~FillConstantBatchLikeCompute() = default;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册