未验证 提交 4561fc37 编写于 作者: W wangchaochaohu 提交者: GitHub

Add check point for gather Op (#26696)

上级 eb097d64
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -152,3 +152,7 @@ REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>, ...@@ -152,3 +152,7 @@ REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
ops::GatherGradientOpKernel<int>, ops::GatherGradientOpKernel<int>,
ops::GatherGradientOpKernel<uint8_t>, ops::GatherGradientOpKernel<uint8_t>,
ops::GatherGradientOpKernel<int64_t>); ops::GatherGradientOpKernel<int64_t>);
REGISTER_OP_VERSION(gather)
.AddCheckpoint(R"ROC(upgrad gather, add attribut [axis])ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"axis", "Specify the axis of gather operation.", {}));
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册