Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • 合并请求
  • !26792

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板

fix the argmin,argmax op for the paddlepaddle 2.0 !26792

  • Report abuse
!26792 已合并 8月 29, 2020 由 saxon_zh@saxon_zh 创建
#<User:0x00007ff7b5b35fb0>
  • 概览 2
  • 提交 5
  • 变更 5

Created by: wawltor

PR types

Bug fixes

PR changes

Others

Describe

fix the argmin,argmax op for the paddlepaddle 2.0

  • Registered the AddCheckPoint for the argmax, argmin, add the flatten attribuate for the argmax/argmin, the flatten means that flattening the tensor as input;At the same time, change the default value from -1 to 3 in the attribute of argmax/argmin, but do not change the operator logic.

  • Fixed the output of argmax/argmin, because the paddlepaddle do not support the scalar tensor, change the output of argmin/argmax from the scalar tensor to normal tensor with shape [1] .

  • Added the check for dtype of argmin/argmax, when the num of element in argmax/argmin is larger the max value of int32, remind the user to use the dtype of int64.

测试改变Argmax, Argmin的属性后是否能进行正常预测且预测值是否为默认dtype

1 . 测试代码一 在1.8版本进行模型的搭建,同时argmax/argmin使用默认dtype属性来保存模型,默认的返回的结果的dtype为int64, 测试代码如下

import paddle
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
np.random.seed(123)
input_var = F.data(name="input_data", shape=[10, 100], dtype="float32")
max_value = L.argmax(input_var, axis=0)
loss = L.reduce_sum(max_value)
exe = F.Executor(F.CUDAPlace(0))
exe.run(F.default_startup_program())


numpy_data = np.random.rand(10, 100).astype("float32")
for i in range(0, 1):
   result, indices = exe.run(feed={"input_data": numpy_data}, fetch_list=[loss, max_value])
print("the result is {}, and the argmax output dtype:{}".format(result, indices.dtype))

F.io.save_inference_model(dirname="./model_path", feeded_var_names=['input_data'], target_vars=[loss, max_value], executor=exe)
print("save the inference model done in paddle version:{}".format(paddle.__version__))

运行且保存模型如下: 图片 2 . 使用paddle-develop版本进行模型的预测,测试代码如下

mport numpy as np
import paddle
import paddle.fluid as F
import paddle.fluid.layers as L

exe = F.Executor(F.CUDAPlace(0))
exe.run(F.default_startup_program())
[inference_program, feed_target_names, fetch_targets] = (F.io.load_inference_model(dirname="./model_path", executor=exe))
print("laod inference model from paddle version:{}".format(paddle.__version__))
np.random.seed(123)
input_np = np.array(np.random.random((10, 100)), dtype=np.float32)
results, indices = exe.run(inference_program,
        feed={feed_target_names[0]: input_np},
        fetch_list=fetch_targets)
print("the result is {}, indices dtype:{}".format(results, indices.dtype))

图片 测试结果正常,跟1.8.4版本的输出结果一致,且argmax/argmin的输出的indices是int64

3 . 测试结论 无兼容性影响

指派人
分配到
审核者
Request review from
无
里程碑
无
分配里程碑
工时统计
标识: paddlepaddle/Paddle!26792
Source branch: github/fork/wawltor/fix_argmin_argmax_keepdims
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7