提交 0412f5e0 编写于 作者: D dzhwinter

"fix ci"

上级 0be1e09f
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include <stdio.h>
#include <algorithm>
#include "paddle/fluid/operators/sequence_expand_op.h"
#include "paddle/fluid/platform/cuda_helper.h"
......@@ -109,12 +108,10 @@ struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* out) {
int x_item_length = 1;
x_item_length = x.numel() / x.dims()[0];
VLOG(0) << "x_item_length" << x_item_length;
int thread_x = std::max(static_cast<int>(ref_lod.size()), 32);
int thread_y = std::max(1024 / thread_x, 16);
int thread_z = std::min(1024 / thread_x / thread_y, 16);
int x_item_length = x.numel() / x.dims()[0];
int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
int thread_y = 16;
int thread_z = 1024 / thread_x / thread_y;
int block_x = static_cast<int>(ref_lod.size());
dim3 block_size(thread_x, thread_y, thread_z);
dim3 grid_size(block_x, 1);
......@@ -133,12 +130,10 @@ struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand based lod*/
LoDTensor* dx) {
int x_item_length = 1;
x_item_length = framework::product(dx->dims()) / dx->dims()[0];
int thread_x = std::max(static_cast<int>(ref_lod.size()), 32);
int thread_y = std::max(1024 / thread_x, 16);
int thread_z = std::min(1024 / thread_x / thread_y, 16);
int x_item_length = framework::product(dx->dims()) / dx->dims()[0];
int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
int thread_y = 16;
int thread_z = 1024 / thread_x / thread_y;
int block_x = static_cast<int>(ref_lod.size());
dim3 block_size(thread_x, thread_y, thread_z);
dim3 grid_size(block_x, 1);
......
......@@ -15,8 +15,6 @@ limitations under the License. */
#pragma once
#include <numeric> // std::iota
#include <glog/logging.h>
#include <sstream>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/math_function.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册