未验证 提交 f66594a5 编写于 作者: Z Zhou Wei 提交者: GitHub

fix bug that diag API can't use on Windows(#24762)

上级 80ec2fe7
...@@ -24,8 +24,8 @@ namespace paddle { ...@@ -24,8 +24,8 @@ namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
struct DiagFunctor { struct DiagonalFunctor {
DiagFunctor(const T* input, const int64_t* diag_stride, DiagonalFunctor(const T* input, const int64_t* diag_stride,
const int64_t* ret_strides, int64_t pos, int64_t dim_size, const int64_t* ret_strides, int64_t pos, int64_t dim_size,
T* diag) T* diag)
: input_(input), : input_(input),
...@@ -157,7 +157,7 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context, ...@@ -157,7 +157,7 @@ framework::Tensor Diagonal(const framework::ExecutionContext& context,
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, diag.numel()); platform::ForRange<DeviceContext> for_range(dev_ctx, diag.numel());
DiagFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size, DiagonalFunctor<T> functor(input_data, diag_arr, ret_arr, pos, dim_size,
diag_data); diag_data);
for_range(functor); for_range(functor);
return diag; return diag;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册