未验证 提交 f66594a5 编写于 作者: Z Zhou Wei 提交者: GitHub

fix bug that diag API can't use on Windows(#24762)

上级 80ec2fe7
......@@ -24,8 +24,8 @@ namespace paddle {
namespace operators {
template <typename T>
struct DiagFunctor {
DiagFunctor(const T* input, const int64_t* diag_stride,
struct DiagonalFunctor {
DiagonalFunctor(const T* input, const int64_t* diag_stride,
const int64_t* ret_strides, int64_t pos, int64_t dim_size,
T* diag)
: input_(input),
......@@ -157,7 +157,7 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context,
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, diag.numel());
DiagFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size,
DiagonalFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size,
diag_data);
for_range(functor);
return diag;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册