diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 04ae66de91d459a9dcadbbca7645210187b05354..5e2024e0ea9040b758e1cec4dbaa4b329bbb727e 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -56,7 +56,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { y->mutable_data(ctx.GetPlace()); math::CrossEntropyFunctor()( - ctx, y, x, label, ctx.Attr("softLabel")); + ctx.device_context(), y, x, label, ctx.Attr("softLabel")); } }; diff --git a/paddle/operators/math/cross_entropy.cu b/paddle/operators/math/cross_entropy.cu index 2c589521c16379029affdd6fe54b432d5c63b8ea..367190e6b0682ec62550e869e2f04c3a2b2cbec3 100644 --- a/paddle/operators/math/cross_entropy.cu +++ b/paddle/operators/math/cross_entropy.cu @@ -74,7 +74,7 @@ using Tensor = framework::Tensor; template class CrossEntropyFunctor { public: - void operator()(const framework::DeviceContext& ctx, framework::Tensor* out, + void operator()(const platform::DeviceContext& ctx, framework::Tensor* out, const framework::Tensor* prob, const framework::Tensor* labels, bool softLabel) { const T* prob_data = prob->data(); diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu index b5a7cda734019e88df1a64ab87dbf5563211a168..2bc53ecf871eb1800a920ba85e8eac31d7037efe 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/operators/softmax_with_cross_entropy_op.cu @@ -69,7 +69,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { math::SoftmaxFunctor()(context.device_context(), logits, softmax); math::CrossEntropyFunctor()( - context, loss, softmax, labels, context.Attr("softLabel")); + context.device_context(), loss, softmax, labels, + context.Attr("softLabel")); } }; diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 5883a55272f0f24c94d48bc43c62ddb7bef15465..f4b00c57dee5196e535816d8985fd7e831c4c226 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -24,7 +24,7 @@ TEST(Device, Init) { for (int i = 0; i < count; i++) { DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i)); Eigen::GpuDevice* gpu_device = - device_context->template get_eigen_device(); + device_context->template GetEigenDevice(); ASSERT_NE(nullptr, gpu_device); delete device_context; }