提交 18aa5949 编写于 作者: L lujun

fix mix input type error, test=develop

上级 1b6a2a09
......@@ -23,14 +23,21 @@ class SaveCombineOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {}
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(),
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
// TODO(lujun): The override here is just to bypass transform
// in operator impl, which is not elegant enough.
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return expected_kernel_type;
}
};
class SaveCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
......@@ -61,7 +68,7 @@ to a file on disk.
"(string)"
"The \"file_path\" where the LoDTensor variables will be saved.")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
[](const std::string& path) { return !path.empty(); });
}
};
......@@ -77,6 +84,4 @@ REGISTER_OP_CPU_KERNEL(
save_combine,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>);
......@@ -20,6 +20,4 @@ REGISTER_OP_CUDA_KERNEL(
save_combine,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::SaveCombineOpKernel<paddle::platform::CUDADeviceContext, int>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册