diff --git a/doc/fluid/advanced_usage/development/new_op/new_op.md b/doc/fluid/advanced_usage/development/new_op/new_op.md index 12611d0f63002816e36b17438d5c8d51fde78a84..89c14967d15535349f62a18d26191446d3b7fcfc 100644 --- a/doc/fluid/advanced_usage/development/new_op/new_op.md +++ b/doc/fluid/advanced_usage/development/new_op/new_op.md @@ -233,7 +233,105 @@ MulOp(const std::string &type, const framework::VariableNameMap &inputs, 通常`OpProtoMaker`和`Op`类的定义写在`.cc`文件中,和下面将要介绍的注册函数一起放在`.cc`中 -**注意:**通常`InferShape`操作在[编译时和运行时](https://github.com/PaddlePaddle/FluidDoc/blob/release/1.2/doc/fluid/getstarted/Developer's_Guide_to_Paddle_Fluid.md#%E8%AE%A9%E6%88%91%E4%BB%AC%E5%9C%A8fluid%E7%A8%8B%E5%BA%8F%E5%AE%9E%E4%BE%8B%E4%B8%AD%E5%8C%BA%E5%88%86%E7%BC%96%E8%AF%91%E6%97%B6%E5%92%8C%E8%BF%90%E8%A1%8C%E6%97%B6)都会被调用,在一些NLP任务可能会涉及到很多的变长操作,Paddle中在编译时变长使用-1表示,所以Op中在做检查或者推断输出变量的Shape时需要注意输入变量的某个维度是否可能为-1。 +### InferShape区分 compile time 和 run time +在我们的静态图网络中,`InferShape`操作在[编译时(compile time)和运行时(run time)](https://github.com/PaddlePaddle/FluidDoc/blob/release/1.2/doc/fluid/getstarted/Developer's_Guide_to_Paddle_Fluid.md#%E8%AE%A9%E6%88%91%E4%BB%AC%E5%9C%A8fluid%E7%A8%8B%E5%BA%8F%E5%AE%9E%E4%BE%8B%E4%B8%AD%E5%8C%BA%E5%88%86%E7%BC%96%E8%AF%91%E6%97%B6%E5%92%8C%E8%BF%90%E8%A1%8C%E6%97%B6)都会被调用,在compile time时,由于真实的维度未知,框架内部用-1来表示,在run time时,用实际的维度表示,因此维度的值在compile time和 run time时可能不一致,如果存在维度的判断和运算操作,InferShape就需要区分compile time 和 run time。 + +以下两种情况需要区分compile time和 run time。 + +**1.检查** + +如以下代码: +```cpp +auto x_dim = ctx->GetInputDim("X"); +int i = xxx; +PADDLE_ENFORCE_GT( x_dim[i] , 10) +``` + +在compile time的时候,x_dim[i]可能等于-1,导致这个PADDLE_ENFORCE_GT报错退出。 + +如果用了以下paddle中定义的宏进行判断: +```cpp +PADDLE_ENFORCE_EQ ( x_dim[i] , 10) +PADDLE_ENFORCE_NE ( x_dim[i] , 10) +PADDLE_ENFORCE_GT ( x_dim[i] , 10) +PADDLE_ENFORCE_GE ( x_dim[i] , 10) +PADDLE_ENFORCE_LT ( x_dim[i] , 10) +PADDLE_ENFORCE_LE ( x_dim[i] , 10) +``` +都需要区分compile time和run time + +**2. 运算** + +如以下代码: +```cpp +auto x_dim = ctx->GetInputDim("X"); +int i = xxx; +y_dim[0] = x_dim[i] + 10 +``` + +在compile time的时候,x_dim[i]可能等于-1,得到的 y_dim[0] 等于 9,是不符合逻辑的 + +如果用到了类似以下的运算操作 +```cpp +y_dim[i] = x_dim[i] + 10 +y_dim[i] = x_dim[i] - 10 +y_dim[i] = x_dim[i] * 10 +y_dim[i] = x_dim[i] / 10 +y_dim[i] = x_dim[i] + z_dim[i] +``` +都需要区分compile time和run time + +**处理的标准**: +- 检查: compile time的时候不判断维度等于-1的情况,但在runtime的时候检查 +- 运算: -1和其他数做任何运算都要等于-1 + +**参考代码** +1. 判断的实现方法可以参考cross_entropy_op.cc,cross_entropy_op 要求X和labels的两个输入,除了最后一维以外,其他的维度完全一致 + +```cpp + bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) || + framework::contain_unknown_dim(label_dims); + bool check = ctx->IsRuntime() || !contain_unknown_dim; + if (check) { + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "Input(X) and Input(Label) shall have the same shape " + "except the last dimension."); + } +``` + +2. 运算的实现可以参考concat_op.cc,concat在InferShape判断时,除了进行concat轴之外,其他的维度完全一致;在生成output的维度时,把concat轴的维度求和,其他的维度和输入保持一致。 + +```cpp + auto out_dims = ins[0]; + size_t in_zero_dims_size = out_dims.size(); + for (size_t i = 1; i < n; i++) { + for (size_t j = 0; j < in_zero_dims_size; j++) { + if (j == axis) { + if (ctx->IsRuntime()) { + out_dims[axis] += ins[i][j]; + } else { + if (ins[i][j] == -1) { + out_dims[axis] = -1; + } else { + out_dims[axis] += ins[i][j]; + } + } + } else { + bool check_shape = + ctx->IsRuntime() || (out_dims[j] > 0 && ins[i][j] > 0); + if (check_shape) { + // check all shape in run time + PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j], + "Input tensors should have the same " + "elements except the specify axis."); + } + } + } + } +``` + + ### 定义OpKernel类