Commit dcc2e62c, authored by Reed Wanderman-Milne, committed by TensorFlower Gardener

Fix Windows GPU build failure in resize_bilinear_op.cc.

I broke it in 67d15573.

Before, I called a function under a `std::is_same<Device, CPUDevice>` condition, and that function cannot be linked if Device is a GPUDevice. I would expect the call not to be generated when Device is a GPUDevice due to dead code elimination, but apparently it still is on Windows.
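
To illustrate the difference, here is a minimal standalone sketch of the dispatch-by-specialization pattern. A plain `if` on `std::is_same<...>::value` is an ordinary branch as far as the language is concerned, so both arms are compiled for every `Device`, and whether the call in the untaken arm is dropped is left to the optimizer. With a specialized functor, only the body selected for the given `Device` is instantiated. All names here (`CpuDevice`, `GpuDevice`, `CastToInt`, `CastViaHelper`, `RunCast`) are hypothetical stand-ins and are not TensorFlow or Eigen APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-ins for the device tag types (not TensorFlow classes).
struct CpuDevice {};
struct GpuDevice {};

// Hypothetical helper standing in for a pre-built cast routine.
inline void CastViaHelper(const std::vector<float>& in, std::vector<int>& out) {
  for (std::size_t i = 0; i < in.size(); ++i) out[i] = static_cast<int>(in[i]);
}

// Primary template: the generic path, used for CpuDevice in this sketch.
template <typename Device>
struct CastToInt {
  std::string operator()(const Device&, const std::vector<float>& in,
                         std::vector<int>& out) const {
    assert(in.size() == out.size());
    for (std::size_t i = 0; i < in.size(); ++i) {
      out[i] = static_cast<int>(in[i]);
    }
    return "generic";
  }
};

// Full specialization: for GpuDevice only this body exists, so the generic
// body above is never instantiated (or referenced) for that device.
template <>
struct CastToInt<GpuDevice> {
  std::string operator()(const GpuDevice&, const std::vector<float>& in,
                         std::vector<int>& out) const {
    assert(in.size() == out.size());
    CastViaHelper(in, out);
    return "gpu";
  }
};

// The caller contains no branch; the choice is made at compile time.
template <typename Device>
std::string RunCast(const std::vector<float>& in, std::vector<int>& out) {
  return CastToInt<Device>{}(Device{}, in, out);
}
```

Calling `RunCast<CpuDevice>(in, out)` picks the primary template and `RunCast<GpuDevice>(in, out)` picks the specialization. In C++17 and later, `if constexpr` inside a template is another way to keep the untaken branch from being instantiated.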

PiperOrigin-RevId: 327887418
Change-Id: Ib97e1abf1680c75dc072850cc69c761e10ac3e1e
Parent d4e7fede
@@ -286,6 +286,25 @@ void resize_image(typename TTypes<T, 4>::ConstTensor images,
   }
 }

+template <typename Device>
+struct CastFloatToHalf {
+  void operator()(const Device& d, typename TTypes<float>::ConstFlat input,
+                  typename TTypes<Eigen::half>::Flat output) {
+    output.device(d) = input.template cast<Eigen::half>();
+  }
+};
+
+template <>
+struct CastFloatToHalf<GPUDevice> {
+  void operator()(const GPUDevice& d, typename TTypes<float>::ConstFlat input,
+                  typename TTypes<Eigen::half>::Flat output) {
+    // Use existing cast functor instead of directly casting Eigen tensor, as
+    // otherwise we need to instantiate the cast function in a .cu.cc file
+    functor::CastFunctor<GPUDevice, Eigen::half, float> cast;
+    cast(d, output, input);
+  }
+};
+
 }  // namespace

 // Partial specialization of ResizeBilinear functor for a CPUDevice.
@@ -378,19 +397,10 @@ class ResizeBilinearOpGrad : public OpKernel {
       functor::ResizeBilinearGrad<Device, float>()(
           context->eigen_device<Device>(), input_grad, st.height_scale,
           st.width_scale, half_pixel_centers_, output_grad.tensor<float, 4>());
-      if (std::is_same<Device, CPUDevice>::value) {
-        const Device& d = context->template eigen_device<Device>();
-        st.output->template flat<Eigen::half>().device(d) =
-            output_grad.template flat<float>().template cast<Eigen::half>();
-      } else {
-        // Use cast functor instead of directly casting Eigen tensor, as
-        // otherwise we need to instantiate the cast function in a .cu.cc file
-        const Tensor& output_grad_const = output_grad;
-        functor::CastFunctor<Device, Eigen::half, float> cast;
-        const Device& device = context->template eigen_device<Device>();
-        cast(device, st.output->template flat<Eigen::half>(),
-             output_grad_const.template flat<float>());
-      }
+      const Tensor& output_grad_const = output_grad;
+      CastFloatToHalf<Device>{}(context->template eigen_device<Device>(),
+                                output_grad_const.template flat<float>(),
+                                st.output->template flat<Eigen::half>());
     }
   }