提交 b611a479 编写于 作者: Q qijun

fix gpu build error

上级 84ff7e97
......@@ -56,7 +56,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
y->mutable_data<T>(ctx.GetPlace());
math::CrossEntropyFunctor<platform::GPUPlace, T>()(
ctx, y, x, label, ctx.Attr<bool>("softLabel"));
ctx.device_context(), y, x, label, ctx.Attr<bool>("softLabel"));
}
};
......
......@@ -74,7 +74,7 @@ using Tensor = framework::Tensor;
template <typename T>
class CrossEntropyFunctor<platform::GPUPlace, T> {
public:
void operator()(const framework::DeviceContext& ctx, framework::Tensor* out,
void operator()(const platform::DeviceContext& ctx, framework::Tensor* out,
const framework::Tensor* prob,
const framework::Tensor* labels, bool softLabel) {
const T* prob_data = prob->data<T>();
......
......@@ -69,7 +69,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
math::SoftmaxFunctor<platform::GPUPlace, T>()(context.device_context(),
logits, softmax);
math::CrossEntropyFunctor<platform::GPUPlace, T>()(
context, loss, softmax, labels, context.Attr<bool>("softLabel"));
context.device_context(), loss, softmax, labels,
context.Attr<bool>("softLabel"));
}
};
......
......@@ -24,7 +24,7 @@ TEST(Device, Init) {
for (int i = 0; i < count; i++) {
DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i));
Eigen::GpuDevice* gpu_device =
device_context->template get_eigen_device<Eigen::GpuDevice>();
device_context->template GetEigenDevice<GPUPlace>();
ASSERT_NE(nullptr, gpu_device);
delete device_context;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册