提交 a3addcdc 编写于 作者: S sweetsky0901

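Update detection_output_op for recent trunk changes: switch kernel templates from Place to DeviceContext, use REGISTER_OP_CUDA_KERNEL, and register the op in the CMake build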

上级 4d8f39b8
......@@ -210,7 +210,8 @@ set(DEPS_OPS
save_op
load_op
send_op
recv_op)
recv_op
detection_output_op)
if(WITH_DISTRIBUTE)
add_subdirectory(detail)
......@@ -233,6 +234,7 @@ op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(softmax_op DEPS softmax)
op_library(detection_output_op DEPS softmax)
op_library(sequence_softmax_op DEPS softmax)
op_library(sum_op DEPS selected_rows_functor)
op_library(sgd_op DEPS selected_rows_functor)
......
......@@ -86,5 +86,5 @@ REGISTER_OP_WITHOUT_GRADIENT(detection_output, ops::Detection_output_Op,
ops::Detection_output_OpMaker);
REGISTER_OP_CPU_KERNEL(
detection_output,
ops::Detection_output_Kernel<paddle::platform::CPUPlace, float>,
ops::Detection_output_Kernel<paddle::platform::CPUPlace, double>);
ops::Detection_output_Kernel<paddle::platform::CPUDeviceContext, float>,
ops::Detection_output_Kernel<paddle::platform::CPUDeviceContext, double>);
......@@ -15,7 +15,7 @@ limitations under the License. */
#include "paddle/operators/detection_output_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
REGISTER_OP_CUDA_KERNEL(
detection_output,
ops::Detection_output_Kernel<paddle::platform::GPUPlace, float>,
ops::Detection_output_Kernel<paddle::platform::GPUPlace, double>);
ops::Detection_output_Kernel<paddle::platform::CUDADeviceContext, float>,
ops::Detection_output_Kernel<paddle::platform::CUDADeviceContext, double>);
......@@ -21,8 +21,8 @@ limitations under the License. */
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
inline void transpose_fun(const platform::DeviceContext& context,
template <typename DeviceContext, typename T>
inline void transpose_fun(const framework::ExecutionContext& context,
const framework::Tensor& src,
framework::Tensor* dst) {
int input_nums = src.dims()[0];
......@@ -36,17 +36,18 @@ inline void transpose_fun(const platform::DeviceContext& context,
framework::Tensor in_p_tensor_transpose;
in_p_tensor_transpose.mutable_data<T>(shape, context.GetPlace());
std::vector<int> shape_axis({0, 1, 3, 4, 2});
math::Transpose<Place, T, 5> trans5;
trans5(context, in_p_tensor, &in_p_tensor_transpose, shape_axis);
math::Transpose<DeviceContext, T, 5> trans5;
trans5(context.template device_context<DeviceContext>(), in_p_tensor,
&in_p_tensor_transpose, shape_axis);
auto dst_stride = framework::stride(dst->dims());
auto src_stride = framework::stride(in_p_tensor_transpose.dims());
StridedMemcpy<T>(context, in_p_tensor_transpose.data<T>(), src_stride,
in_p_tensor_transpose.dims(), dst_stride,
StridedMemcpy<T>(context.device_context(), in_p_tensor_transpose.data<T>(),
src_stride, in_p_tensor_transpose.dims(), dst_stride,
dst->data<T>() + offset);
offset += in_p_tensor_transpose.dims()[4] * src_stride[4];
}
}
template <typename Place, typename T>
template <typename DeviceContext, typename T>
class Detection_output_Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -87,11 +88,12 @@ class Detection_output_Kernel : public framework::OpKernel<T> {
framework::Tensor conf_cpu;
framework::Tensor priorbox_cpu;
const T* priorbox_data = in_priorbox->data<T>();
transpose_fun<Place, T>(context.device_context(), *in_loc, &loc_tensor);
transpose_fun<Place, T>(context.device_context(), *in_conf, &conf_tensor);
transpose_fun<DeviceContext, T>(context, *in_loc, &loc_tensor);
transpose_fun<DeviceContext, T>(context, *in_conf, &conf_tensor);
conf_tensor.Resize(conf_shape_softmax);
math::SoftmaxFunctor<Place, T>()(context.device_context(), &conf_tensor,
&conf_tensor);
math::SoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), &conf_tensor,
&conf_tensor);
T* loc_data = loc_tensor.data<T>();
T* conf_data = conf_tensor.data<T>();
if (platform::is_gpu_place(context.GetPlace())) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册