未验证 提交 206a8f6c 编写于 作者: L limingshu 提交者: GitHub

code clean (#38550)

上级 9171aaa0
...@@ -22,146 +22,6 @@ namespace operators { ...@@ -22,146 +22,6 @@ namespace operators {
namespace kps = paddle::operators::kernel_primitives; namespace kps = paddle::operators::kernel_primitives;
struct DimensionsTransform {
using DimVector = std::vector<int64_t>;
typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
int, int);
int64_t dim_size;
DimVector out_dims;
std::vector<DimVector> in_dims;
private:
// To compensate the lackage of input_tensors` dimension with input variable
// 'axis'
void InputDimensionsExtend(int N, int axis) {
for (auto &in_dim : in_dims) {
int64_t in_idx = 0;
if (in_dim.size() < dim_size) {
DimVector tmp_dim(dim_size, 1);
do {
if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
tmp_dim[axis] = in_dim[in_idx];
in_idx++;
axis++;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The %d-th dimension of input tensor is expected to be equal "
"with the %d-th dimension of output tensor %d or 1, but "
"recieved %d.",
in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx]));
}
} while (in_idx < in_dim.size());
in_dim.resize(dim_size);
std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
} else {
do {
if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
in_idx++;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The %d-th dimension of input tensor is expected to be equal "
"with the %d-th dimension of output tensor %d or 1, but "
"recieved %d.",
in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
}
} while (in_idx < dim_size);
}
std::reverse(in_dim.begin(), in_dim.end());
}
std::reverse(out_dims.begin(), out_dims.end());
}
template <typename MergeFunctor>
__inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
(*vec)[m_idx - 1] =
std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx, 1,
std::multiplies<int64_t>());
vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
};
int64_t i = 0;
while (i < dim_size) {
int cnt = 0;
int low_idx = i;
bool equal = true;
do {
merge_func(equal, in_dims, out_dims, i, N);
if (equal) {
i++;
cnt++;
} else {
break;
}
} while (i < dim_size);
if (cnt > 1) {
for (auto &in_dim : in_dims) {
VectorReorganise(&in_dim, low_idx, i);
}
VectorReorganise(&out_dims, low_idx, i);
dim_size -= --cnt;
i -= cnt;
} else if (cnt < 1) {
i++;
}
}
}
public:
explicit DimensionsTransform(
const std::vector<const framework::Tensor *> &ins,
const framework::DDim &dims, int axis) {
const int N = ins.size();
dim_size = dims.size();
out_dims = framework::vectorize<int64_t>(dims);
in_dims.resize(N);
for (int j = 0; j < N; ++j) {
in_dims[j] = framework::vectorize<int64_t>(ins[j]->dims());
}
InputDimensionsExtend(N, axis);
auto merge_sequential_dims = [](bool &equal,
std::vector<DimVector> &in_dims,
DimVector &out, int i, int num) {
for (int j = 1; j < num; ++j) {
equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false;
}
};
auto merge_sequential_one_dims = [](bool &equal,
std::vector<DimVector> &in_dims,
DimVector &out, int i, int num) {
equal = in_dims[0][i] == 1;
if (equal) {
for (int j = 1; j < num; ++j) {
equal &= in_dims[j][i] == out[i];
}
}
};
// To Merge the dimensions of input_tensors while the consequtive
// equal-dimensions appears.
MergeFunctor merge_ptr = merge_sequential_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N);
int min_idx = 0;
int min_val = std::accumulate(in_dims[0].begin(), in_dims[0].end(), 1,
std::multiplies<int64_t>());
for (int j = 1; j < N; ++j) {
int temp = std::accumulate(in_dims[j].begin(), in_dims[j].end(), 1,
std::multiplies<int64_t>());
min_val = min_val > temp ? temp : min_val;
min_idx = min_val == temp ? j : min_idx;
}
std::swap(in_dims[0], in_dims[min_idx]);
// To Merge the dimension of input_tensors while the consequtive
// 1-value-dimensions appears.
merge_ptr = merge_sequential_one_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N);
std::swap(in_dims[min_idx], in_dims[0]);
}
};
template <ElementwiseType ET, typename InT, typename OutT, typename Functor, template <ElementwiseType ET, typename InT, typename OutT, typename Functor,
int NumOuts = 1> int NumOuts = 1>
void LaunchBroadcastElementwiseCudaKernel( void LaunchBroadcastElementwiseCudaKernel(
......
...@@ -25,12 +25,6 @@ limitations under the License. */ ...@@ -25,12 +25,6 @@ limitations under the License. */
#include "paddle/pten/include/core.h" #include "paddle/pten/include/core.h"
#include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h" #include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h"
#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
#else
#define ELEMENTWISE_BLOCK_SIZE 512
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -456,7 +456,7 @@ void LaunchBroadcastElementwiseCudaKernel( ...@@ -456,7 +456,7 @@ void LaunchBroadcastElementwiseCudaKernel(
ins.size(), ins.size(),
kArity)); kArity));
PADDLE_ENFORCE_LE(kArity, PADDLE_ENFORCE_LE(kArity,
ElementwiseType::kTernary, 3,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"Currently only broadcast of ternary is supported " "Currently only broadcast of ternary is supported "
"and verified, but received %d.", "and verified, but received %d.",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册