diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc index 619acfc8b62ed8e50dd60d110280d981f5c1e3ab..cadd8841b6ab3a3674054240265eb6d4b474db1e 100644 --- a/paddle/operators/scatter_op.cc +++ b/paddle/operators/scatter_op.cc @@ -51,7 +51,7 @@ class ScatterOp : public framework::OperatorWithKernel { framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { - return framework::ToDataType(ctx.Input("X")->type()); + return framework::ToDataType(ctx.Input("Ref")->type()); } }; @@ -68,7 +68,7 @@ class ScatterGradOp : public framework::OperatorWithKernel { framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { - return framework::ToDataType(ctx.Input("X")->type()); + return framework::ToDataType(ctx.Input("Ref")->type()); } }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index de7c532421c48927c77309db4fc713f9b5209967..a76489871f30dc8d852b6a783efeff41704fd4a4 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -158,7 +158,8 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { - return framework::ToDataType(ctx.Input("Logits")->type()); + return framework::ToDataType( + ctx.Input(framework::GradVarName("Loss"))->type()); } };