未验证 提交 181f7cec 编写于 作者: N niuliling123 提交者: GitHub

fix a bug in nlp: text_matching/sentence_transformers when last dim is 1 and...

fix a bug in nlp: text_matching/sentence_transformers when last dim is 1 and reduce mid dim (#34941)
上级 ed6624ab
......@@ -770,7 +770,7 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
auto x_dim = framework::vectorize<int>(x.dims());
auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
config.Run(); // get the parameters of LaunchReduceKernel
int numel = x.numel();
// after config.run()
// SetOutputData for ReduceHigherDim when should_reduce_again is true,
// temp_output should be stored temp_data in output_data space or stored in
......@@ -787,7 +787,7 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
}
config.SetOutputData(y_data, x.place(), &tmp);
bool use_cub_reduce = (config.left_num == 1) &&
bool use_cub_reduce = (config.reduce_num == numel) &&
(!std::is_same<Tx, paddle::platform::float16>::value);
if (use_cub_reduce) {
// launch CUB::Reduce
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册