未验证 提交 b2122239 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] register load_combine op (#45980)

上级 c7b373f2
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/custom_device_common_op_registry.h" #include "paddle/fluid/operators/custom_device_common_op_registry.h"
#include "paddle/fluid/operators/load_combine_op.h"
#include "paddle/fluid/operators/run_program_op.h" #include "paddle/fluid/operators/run_program_op.h"
#include "paddle/fluid/operators/save_combine_op.h" #include "paddle/fluid/operators/save_combine_op.h"
#include "paddle/phi/backends/device_manager.h" #include "paddle/phi/backends/device_manager.h"
...@@ -52,6 +53,19 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { ...@@ -52,6 +53,19 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
SaveCombineOpKernel<paddle::platform::CustomDeviceContext, int>, SaveCombineOpKernel<paddle::platform::CustomDeviceContext, int>,
paddle::operators :: paddle::operators ::
SaveCombineOpKernel<paddle::platform::CustomDeviceContext, int64_t>); SaveCombineOpKernel<paddle::platform::CustomDeviceContext, int64_t>);
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
load_combine,
device_type,
paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, float>,
paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, double>,
paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int>,
paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int8_t>,
paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int64_t>);
} }
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册