提交 1b6a2a09 编写于 作者: L lujun

fix mix input type error, test=develop

上级 b3f5876e
...@@ -88,4 +88,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -88,4 +88,5 @@ REGISTER_OP_CPU_KERNEL(
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, float>, ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, double>, ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int>, ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -64,4 +64,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -64,4 +64,5 @@ REGISTER_OP_CPU_KERNEL(
load, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, float>, load, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, double>, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int>, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -24,6 +24,13 @@ class SaveCombineOp : public framework::OperatorWithKernel { ...@@ -24,6 +24,13 @@ class SaveCombineOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {} void InferShape(framework::InferShapeContext *ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(),
ctx.GetPlace());
}
}; };
class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
...@@ -71,4 +78,5 @@ REGISTER_OP_CPU_KERNEL( ...@@ -71,4 +78,5 @@ REGISTER_OP_CPU_KERNEL(
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, float>, ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, double>, ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>, ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册