未验证 提交 181f7cec 编写于 作者: N niuliling123 提交者: GitHub

fix a bug in nlp: text_matching/sentence_transformers when last dim is 1 and...

fix a bug in nlp: text_matching/sentence_transformers when last dim is 1 and reduce mid dim (#34941)
上级 ed6624ab
...@@ -770,7 +770,7 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y, ...@@ -770,7 +770,7 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
auto x_dim = framework::vectorize<int>(x.dims()); auto x_dim = framework::vectorize<int>(x.dims());
auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim); auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
config.Run(); // get the parameters of LaunchReduceKernel config.Run(); // get the parameters of LaunchReduceKernel
int numel = x.numel();
// after config.run() // after config.run()
// SetOutputData for ReduceHigherDim when should_reduce_again is true, // SetOutputData for ReduceHigherDim when should_reduce_again is true,
// temp_output should be stored temp_data in output_data space or stored in // temp_output should be stored temp_data in output_data space or stored in
...@@ -787,7 +787,7 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y, ...@@ -787,7 +787,7 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
} }
config.SetOutputData(y_data, x.place(), &tmp); config.SetOutputData(y_data, x.place(), &tmp);
bool use_cub_reduce = (config.left_num == 1) && bool use_cub_reduce = (config.reduce_num == numel) &&
(!std::is_same<Tx, paddle::platform::float16>::value); (!std::is_same<Tx, paddle::platform::float16>::value);
if (use_cub_reduce) { if (use_cub_reduce) {
// launch CUB::Reduce // launch CUB::Reduce
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册