提交 18aa5949 编写于 作者: L lujun

fix mix input type error, test=develop

上级 1b6a2a09
...@@ -23,14 +23,21 @@ class SaveCombineOp : public framework::OperatorWithKernel { ...@@ -23,14 +23,21 @@ class SaveCombineOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(), return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace()); ctx.GetPlace());
} }
// TODO(lujun): The override here is just to bypass transform
// in operator impl, which is not elegant enough.
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return expected_kernel_type;
}
}; };
class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker { class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
...@@ -61,7 +68,7 @@ to a file on disk. ...@@ -61,7 +68,7 @@ to a file on disk.
"(string)" "(string)"
"The \"file_path\" where the LoDTensor variables will be saved.") "The \"file_path\" where the LoDTensor variables will be saved.")
.AddCustomChecker( .AddCustomChecker(
[](const std::string &path) { return !path.empty(); }); [](const std::string& path) { return !path.empty(); });
} }
}; };
...@@ -77,6 +84,4 @@ REGISTER_OP_CPU_KERNEL( ...@@ -77,6 +84,4 @@ REGISTER_OP_CPU_KERNEL(
save_combine, save_combine,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, float>, ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, double>, ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>, ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>);
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -20,6 +20,4 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -20,6 +20,4 @@ REGISTER_OP_CUDA_KERNEL(
save_combine, save_combine,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, float>, ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, double>, ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int>, ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int>);
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册