未验证 提交 ab065563 编写于 作者: Z Zhou Wei 提交者: GitHub

[CHERRY-PICK 1.8]fix bug that diag API can't use on Windows(#24825)

* cherry-pick #24762
上级 863f9e55
...@@ -24,8 +24,8 @@ namespace paddle { ...@@ -24,8 +24,8 @@ namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
struct DiagFunctor { struct DiagonalFunctor {
DiagFunctor(const T* input, const int64_t* diag_stride, DiagonalFunctor(const T* input, const int64_t* diag_stride,
const int64_t* ret_strides, int64_t pos, int64_t dim_size, const int64_t* ret_strides, int64_t pos, int64_t dim_size,
T* diag) T* diag)
: input_(input), : input_(input),
...@@ -157,7 +157,7 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context, ...@@ -157,7 +157,7 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context,
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, diag.numel()); platform::ForRange<DeviceContext> for_range(dev_ctx, diag.numel());
DiagFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size, DiagonalFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size,
diag_data); diag_data);
for_range(functor); for_range(functor);
return diag; return diag;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册