提交 a3eb0c0b 编写于 作者: H hjchen2

Fix softmax bug

上级 5df51432
...@@ -148,8 +148,8 @@ class Tensor : public TensorBase { ...@@ -148,8 +148,8 @@ class Tensor : public TensorBase {
PADDLE_MOBILE_ENFORCE( PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value || (std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code()), holder_->type().hash_code() == typeid(T).hash_code()),
"Tensor holds the wrong type, it holds %s", "Tensor holds the wrong type, it holds %s, requested %s",
this->holder_->type().name()); this->holder_->type().name(), typeid(T).name());
return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_); offset_);
...@@ -162,7 +162,7 @@ class Tensor : public TensorBase { ...@@ -162,7 +162,7 @@ class Tensor : public TensorBase {
PADDLE_MOBILE_ENFORCE( PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value || (std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code()), holder_->type().hash_code() == typeid(T).hash_code()),
"Tensor holds the wrong type, it holds %s ,requested:%s", "Tensor holds the wrong type, it holds %s, requested %s",
this->holder_->type().name(), typeid(T).name()); this->holder_->type().name(), typeid(T).name());
return reinterpret_cast<const T *>( return reinterpret_cast<const T *>(
......
...@@ -128,7 +128,7 @@ void SoftmaxFuntor<CPU, float>::operator()(const framework::Tensor *X, ...@@ -128,7 +128,7 @@ void SoftmaxFuntor<CPU, float>::operator()(const framework::Tensor *X,
x0 = vmulq_f32(x0, __inv_sum); x0 = vmulq_f32(x0, __inv_sum);
x1 = vmulq_f32(x1, __inv_sum); x1 = vmulq_f32(x1, __inv_sum);
vst1q_f32(output, x0); vst1q_f32(output, x0);
vst1q_f32(output + 4, x0); vst1q_f32(output + 4, x1);
} }
#endif #endif
for (int i = 0; i < remain; ++i) { for (int i = 0; i < remain; ++i) {
......
...@@ -28,6 +28,7 @@ void load_images(const char *image_dir, const char *images_list, ...@@ -28,6 +28,7 @@ void load_images(const char *image_dir, const char *images_list,
image_shapes->push_back(std::make_pair(height, width)); image_shapes->push_back(std::make_pair(height, width));
image_names->push_back(filename); image_names->push_back(filename);
} }
if_list.close();
} }
int main(int argc, char **argv) { int main(int argc, char **argv) {
...@@ -53,7 +54,7 @@ int main(int argc, char **argv) { ...@@ -53,7 +54,7 @@ int main(int argc, char **argv) {
for (int i = 0; i < image_names.size(); i++) { for (int i = 0; i < image_names.size(); i++) {
std::string file_name = image_names[i]; std::string file_name = image_names[i];
std::vector<float> input; std::vector<float> input_vec;
std::vector<int64_t> dims{1, 1, 48, 512}; std::vector<int64_t> dims{1, 1, 48, 512};
dims[2] = image_shapes[i].first; dims[2] = image_shapes[i].first;
dims[3] = image_shapes[i].second; dims[3] = image_shapes[i].second;
...@@ -62,14 +63,22 @@ int main(int argc, char **argv) { ...@@ -62,14 +63,22 @@ int main(int argc, char **argv) {
std::cerr << "img_path: " << img_path << std::endl; std::cerr << "img_path: " << img_path << std::endl;
std::cerr << "shape = [" << dims[0] << ", " << dims[1] << ", " << dims[2] std::cerr << "shape = [" << dims[0] << ", " << dims[1] << ", " << dims[2]
<< ", " << dims[3] << "]" << std::endl; << ", " << dims[3] << "]" << std::endl;
GetInput<float>(img_path, &input, dims); GetInput<float>(img_path, &input_vec, dims);
framework::Tensor input(input_vec, framework::make_ddim(dims));
// predict // predict
auto output = paddle_mobile.Predict(input, dims); paddle_mobile.Predict(input);
auto output_topk = paddle_mobile.Fetch("top_k_1.tmp_0");
auto output_indices = paddle_mobile.Fetch("cast_68.tmp_0");
// print result // print result
std::cerr << file_name << std::endl; std::cerr << file_name << std::endl;
std::cerr << output[0]; std::cerr << output_topk->data<float>()[0];
for (int j = 1; j < output.size(); ++j) { for (int j = 1; j < output_topk->numel(); ++j) {
std::cerr << " " << output[j]; std::cerr << " " << output_topk->data<float>()[j];
}
std::cerr << std::endl;
std::cerr << output_indices->data<float>()[0];
for (int j = 1; j < output_indices->numel(); ++j) {
std::cerr << " " << output_indices->data<float>()[j];
} }
std::cerr << std::endl; std::cerr << std::endl;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册