提交 02655a22 编写于 作者: Y Yu Yang

Return Reference Instead Pointer to GetEigenDevice

上级 bfaea910
......@@ -20,16 +20,16 @@ namespace paddle {
namespace framework {
template <>
Eigen::DefaultDevice* ExecutionContext::GetEigenDevice<
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
return *device_context_.get_eigen_device<Eigen::DefaultDevice>();
}
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice*
Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>();
return *device_context_.get_eigen_device<Eigen::GpuDevice>();
}
#endif
......
......@@ -253,7 +253,7 @@ class ExecutionContext : public OperatorContext {
template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const;
DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); }
......
......@@ -28,8 +28,7 @@ public:
output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(*output).device(context.GetEigenDevice<Place>()) =
framework::EigenVector<T>::Flatten(*input0) +
framework::EigenVector<T>::Flatten(*input1);
}
......
......@@ -27,7 +27,7 @@ public:
output->mutable_data<T>(context.GetPlace());
EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenScalar<T>::From(*output).device(context.GetEigenDevice<Place>()) =
EigenVector<T>::Flatten(*input).mean();
}
};
......
......@@ -29,7 +29,7 @@ public:
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenMatrix<T>::From(*output).device(context.GetEigenDevice<Place>()) =
EigenMatrix<T>::From(*context.Input<Tensor>("X"))
.contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
dim_pair);
......
......@@ -33,7 +33,7 @@ public:
const int rest_size = input.size() / bias_size;
Eigen::DSizes<int, 1> one_d(input.size());
Eigen::DSizes<int, 1> bcast(rest_size);
output.reshape(one_d).device(*(context.GetEigenDevice<Place>())) =
output.reshape(one_d).device(context.GetEigenDevice<Place>()) =
input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d);
}
};
......
......@@ -29,7 +29,7 @@ public:
param_out->mutable_data<T>(ctx.GetPlace());
EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(*param_out).device(ctx.GetEigenDevice<Place>()) =
EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
}
};
......
......@@ -27,8 +27,7 @@ public:
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(*output).device(context.GetEigenDevice<Place>()) =
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
}
};
......
......@@ -46,9 +46,9 @@ public:
.reshape(batch_by_one)
.broadcast(one_by_class));
softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp();
softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(*(context.GetEigenDevice<Place>())) =
softmax.device(context.GetEigenDevice<Place>()) =
(softmax *
softmax.sum(along_class)
.inverse()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册