Created by: wawltor
PR types
Bug fixes
PR changes
Others
Describe
Fix the argmin, argmax ops for paddlepaddle 2.0.
- Registered an AddCheckpoint for argmax/argmin and added the flatten attribute to argmax/argmin; flatten means the input tensor is flattened before the index is computed. At the same time, changed the default value of the dtype attribute of argmax/argmin from -1 to 3 (int64), without changing the operator logic.
- Fixed the output of argmax/argmin: because paddlepaddle does not support scalar tensors, changed the output of argmin/argmax from a scalar tensor to a normal tensor with shape [1].
- Added a dtype check for argmin/argmax: when the number of elements in the input of argmax/argmin is larger than the maximum value of int32, remind the user to use the int64 dtype.
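For illustration only, a minimal sketch of the intended 2.0 behavior (assuming the paddle.argmax dygraph API, where omitting axis corresponds to the new flatten attribute and the default index dtype is int64):
import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_tensor(np.random.rand(10, 100).astype("float32"))

# with an axis: indices along that axis, default dtype int64
idx_axis0 = paddle.argmax(x, axis=0)
print(idx_axis0.shape, idx_axis0.dtype)  # [100], paddle.int64

# without an axis: the input is flattened first (the new flatten attribute);
# per this PR the result is a shape-[1] tensor rather than a true scalar
idx_flat = paddle.argmax(x)
print(idx_flat.shape, idx_flat.dtype)

# for inputs with more elements than int32 can index, request int64 explicitly
idx_big = paddle.argmax(x, axis=0, dtype="int64")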
Test whether Argmax/Argmin still produce correct predictions after the attribute changes, and whether the predicted indices use the default dtype.
1. Test code 1: build the model in the 1.8 release, with argmax/argmin using the default dtype attribute, and save the model; the default dtype of the returned indices is int64. Test code:
import paddle
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
np.random.seed(123)
input_var = F.data(name="input_data", shape=[10, 100], dtype="float32")
max_value = L.argmax(input_var, axis=0)
loss = L.reduce_sum(max_value)
exe = F.Executor(F.CUDAPlace(0))
exe.run(F.default_startup_program())
numpy_data = np.random.rand(10, 100).astype("float32")
for i in range(0, 1):
    result, indices = exe.run(feed={"input_data": numpy_data}, fetch_list=[loss, max_value])
    print("the result is {}, and the argmax output dtype:{}".format(result, indices.dtype))
F.io.save_inference_model(dirname="./model_path", feeded_var_names=['input_data'], target_vars=[loss, max_value], executor=exe)
print("save the inference model done in paddle version:{}".format(paddle.__version__))
The model runs and is saved successfully. 2. Use the paddle-develop version to load the saved model and run inference. Test code:
import numpy as np
import paddle
import paddle.fluid as F
import paddle.fluid.layers as L
exe = F.Executor(F.CUDAPlace(0))
exe.run(F.default_startup_program())
[inference_program, feed_target_names, fetch_targets] = (F.io.load_inference_model(dirname="./model_path", executor=exe))
print("laod inference model from paddle version:{}".format(paddle.__version__))
np.random.seed(123)
input_np = np.array(np.random.random((10, 100)), dtype=np.float32)
results, indices = exe.run(inference_program,
feed={feed_target_names[0]: input_np},
fetch_list=fetch_targets)
print("the result is {}, indices dtype:{}".format(results, indices.dtype))
The test result is correct and matches the output of version 1.8.4, and the indices returned by argmax/argmin are int64.
3. Test conclusion: no compatibility impact.