提交 02655a22 编写于 作者: Y Yu Yang

Return Reference Instead Pointer to GetEigenDevice

上级 bfaea910
...@@ -20,16 +20,16 @@ namespace paddle { ...@@ -20,16 +20,16 @@ namespace paddle {
namespace framework { namespace framework {
template <> template <>
Eigen::DefaultDevice* ExecutionContext::GetEigenDevice< Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const { platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>(); return *device_context_.get_eigen_device<Eigen::DefaultDevice>();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice* Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>(); return *device_context_.get_eigen_device<Eigen::GpuDevice>();
} }
#endif #endif
......
...@@ -253,7 +253,7 @@ class ExecutionContext : public OperatorContext { ...@@ -253,7 +253,7 @@ class ExecutionContext : public OperatorContext {
template <typename PlaceType, template <typename PlaceType,
typename DeviceType = typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType> typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const; DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); } platform::Place GetPlace() const { return device_context_.GetPlace(); }
......
...@@ -28,8 +28,7 @@ public: ...@@ -28,8 +28,7 @@ public:
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(context.GetEigenDevice<Place>()) =
*(context.GetEigenDevice<Place>())) =
framework::EigenVector<T>::Flatten(*input0) + framework::EigenVector<T>::Flatten(*input0) +
framework::EigenVector<T>::Flatten(*input1); framework::EigenVector<T>::Flatten(*input1);
} }
......
...@@ -27,7 +27,7 @@ public: ...@@ -27,7 +27,7 @@ public:
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) = EigenScalar<T>::From(*output).device(context.GetEigenDevice<Place>()) =
EigenVector<T>::Flatten(*input).mean(); EigenVector<T>::Flatten(*input).mean();
} }
}; };
......
...@@ -29,7 +29,7 @@ public: ...@@ -29,7 +29,7 @@ public:
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) = EigenMatrix<T>::From(*output).device(context.GetEigenDevice<Place>()) =
EigenMatrix<T>::From(*context.Input<Tensor>("X")) EigenMatrix<T>::From(*context.Input<Tensor>("X"))
.contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")), .contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
dim_pair); dim_pair);
......
...@@ -33,7 +33,7 @@ public: ...@@ -33,7 +33,7 @@ public:
const int rest_size = input.size() / bias_size; const int rest_size = input.size() / bias_size;
Eigen::DSizes<int, 1> one_d(input.size()); Eigen::DSizes<int, 1> one_d(input.size());
Eigen::DSizes<int, 1> bcast(rest_size); Eigen::DSizes<int, 1> bcast(rest_size);
output.reshape(one_d).device(*(context.GetEigenDevice<Place>())) = output.reshape(one_d).device(context.GetEigenDevice<Place>()) =
input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d); input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d);
} }
}; };
......
...@@ -29,7 +29,7 @@ public: ...@@ -29,7 +29,7 @@ public:
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) = EigenVector<T>::Flatten(*param_out).device(ctx.GetEigenDevice<Place>()) =
EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad); EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
} }
}; };
......
...@@ -27,8 +27,7 @@ public: ...@@ -27,8 +27,7 @@ public:
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(context.GetEigenDevice<Place>()) =
*(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp()); 1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
} }
}; };
......
...@@ -46,9 +46,9 @@ public: ...@@ -46,9 +46,9 @@ public:
.reshape(batch_by_one) .reshape(batch_by_one)
.broadcast(one_by_class)); .broadcast(one_by_class));
softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp(); softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(*(context.GetEigenDevice<Place>())) = softmax.device(context.GetEigenDevice<Place>()) =
(softmax * (softmax *
softmax.sum(along_class) softmax.sum(along_class)
.inverse() .inverse()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册