提交 ae3dca77 编写于 作者: Y Yu Yang

Fix CI

上级 f1913d46
...@@ -51,7 +51,7 @@ class ScatterOp : public framework::OperatorWithKernel { ...@@ -51,7 +51,7 @@ class ScatterOp : public framework::OperatorWithKernel {
framework::DataType IndicateDataType( framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type()); return framework::ToDataType(ctx.Input<Tensor>("Ref")->type());
} }
}; };
...@@ -68,7 +68,7 @@ class ScatterGradOp : public framework::OperatorWithKernel { ...@@ -68,7 +68,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
framework::DataType IndicateDataType( framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type()); return framework::ToDataType(ctx.Input<Tensor>("Ref")->type());
} }
}; };
......
...@@ -158,7 +158,8 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { ...@@ -158,7 +158,8 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
framework::DataType IndicateDataType( framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Logits")->type()); return framework::ToDataType(
ctx.Input<Tensor>(framework::GradVarName("Loss"))->type());
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册