未验证 提交 f66594a5 编写于 作者: Z Zhou Wei 提交者: GitHub

fix bug that diag API can't use on Windows(#24762)

上级 80ec2fe7
......@@ -24,10 +24,10 @@ namespace paddle {
namespace operators {
template <typename T>
struct DiagFunctor {
DiagFunctor(const T* input, const int64_t* diag_stride,
const int64_t* ret_strides, int64_t pos, int64_t dim_size,
T* diag)
struct DiagonalFunctor {
DiagonalFunctor(const T* input, const int64_t* diag_stride,
const int64_t* ret_strides, int64_t pos, int64_t dim_size,
T* diag)
: input_(input),
diag_stride_(diag_stride),
ret_strides_(ret_strides),
......@@ -157,8 +157,8 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context,
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, diag.numel());
DiagFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size,
diag_data);
DiagonalFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size,
diag_data);
for_range(functor);
return diag;
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册