未验证 提交 8fc05e03 编写于 作者: W Wu Yi 提交者: GitHub

fix cpu build test=develop (#14260)

上级 4dbc0184
......@@ -26,7 +26,7 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
auto* out = context.Output<framework::Tensor>("Out");
auto in_list = context.MultiInput<framework::Tensor>("X");
auto* trainer_id_t = context.Input<framework::Tensor>("TrainerId");
int64_t trainer_id;
int64_t trainer_id = 0;
auto* trainer_id_data = trainer_id_t->data<int64_t>();
if (platform::is_gpu_place(context.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
......@@ -38,7 +38,6 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
} else {
trainer_id = *trainer_id_data;
}
printf("after get trainer_id %lu\n", trainer_id);
PADDLE_ENFORCE_LT(trainer_id, in_list.size());
out->mutable_data<T>(context.GetPlace());
out->ShareDataWith(*(in_list[trainer_id]));
......
......@@ -1588,7 +1588,6 @@ to transpile() call.")
ref_inputs = []
for p, p_bak in self.param_bak_list:
if p.name == param_var.name:
print("#### ref inputs: ", param_var.name, p_bak.name)
ref_inputs.append(p_bak)
block.append_op(
type="ref_by_trainer_id",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册