提交 20654cf7 编写于 作者: S sweetsky0901

Modify the unpool operator for the type-check rewrite: give the max-unpooling functors and kernels a separate template parameter (T2) for the indices tensor's element type, instantiate with T2 = int, add GetKernelType overrides keyed on the input tensor's data type, and switch the Python test's indices dtype from int16 to int32.

上级 27cf7f33
......@@ -19,8 +19,8 @@ namespace operators {
namespace math {
// All tensors are in NCHW format
template <typename T>
class Unpool2dMaxFunctor<platform::CPUPlace, T> {
template <typename T, typename T2>
class Unpool2dMaxFunctor<platform::CPUPlace, T, T2> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
......@@ -35,7 +35,7 @@ class Unpool2dMaxFunctor<platform::CPUPlace, T> {
int input_feasize = input_height * input_width;
int output_feasize = output_height * output_width;
const T* input_data = input.data<T>();
const T * indices_data = indices.data<T>();
const T2 * indices_data = indices.data<T2>();
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int b = 0; b < batch_size; ++b) {
for (int c = 0; c < output_channels; ++c) {
......@@ -54,8 +54,8 @@ class Unpool2dMaxFunctor<platform::CPUPlace, T> {
template <class T>
class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
template <class T, typename T2>
class Unpool2dMaxGradFunctor<platform::CPUPlace, T, T2> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
......@@ -71,7 +71,7 @@ public:
const int output_width = output.dims()[3];
int input_feasize = input_height * input_width;
int output_feasize = output_height * output_width;
const T* indices_data = indices.data<T>();
const T2 * indices_data = indices.data<T2>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
......@@ -90,10 +90,10 @@ public:
}
};
template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, float, int>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, double, int>;
template class Unpool2dMaxFunctor<platform::CPUPlace, float, int>;
template class Unpool2dMaxFunctor<platform::CPUPlace, double, int>;
} // namespace math
} // namespace operators
......
......@@ -19,10 +19,10 @@ namespace paddle {
namespace operators {
namespace math {
template <typename T>
template <typename T, typename T2>
__global__ void KernelUnpool2dMax(const int nthreads,
const T* input_data,
const T* indices_data,
const T2 * indices_data,
const int input_height,
const int input_width,
const int channels,
......@@ -45,10 +45,10 @@ __global__ void KernelUnpool2dMax(const int nthreads,
output_data[out_offset + out_index] = input_data[i];
}
}
template <typename T>
template <typename T, typename T2>
__global__ void KernelUnpool2dMaxGrad(const int nthreads,
const T* input_data,
const T* indices_data,
const T2* indices_data,
const int input_height,
const int input_width,
const int channels,
......@@ -76,8 +76,8 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
/*
* All tensors are in NCHW format.
*/
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
template <typename T, typename T2>
class Unpool2dMaxFunctor<platform::GPUPlace, T, T2> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
......@@ -90,7 +90,7 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T> {
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const T* input_data = input.data<T>();
const T* indices_data = indices.data<T>();
const T2 * indices_data = indices.data<T2>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int nthreads = batch_size * output_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
......@@ -98,7 +98,7 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T> {
dim3 grid(blocks, 1);
KernelUnpool2dMax<
T><<<grid, threads, 0,
T, T2><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, input_data, indices_data,
input_height, input_width, output_channels,
......@@ -108,8 +108,8 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T> {
/*
* All tensors are in NCHW format.
*/
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
template <typename T, typename T2>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T, T2> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
......@@ -124,7 +124,7 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const T* input_data = input.data<T>();
const T* indices_data = indices.data<T>();
const T2 * indices_data = indices.data<T2>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
......@@ -134,7 +134,7 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
dim3 grid(blocks, 1);
KernelUnpool2dMaxGrad<
T><<<grid, threads, 0,
T, T2><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
nthreads, input_data, indices_data,
......@@ -145,11 +145,11 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
}
};
template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, float, int>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double, int>;
template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;
template class Unpool2dMaxFunctor<platform::GPUPlace, float, int>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double, int>;
} // namespace math
} // namespace operators
......
......@@ -19,7 +19,7 @@ namespace paddle {
namespace operators {
namespace math {
template <typename Place, typename T>
template <typename Place, typename T, typename T2>
class Unpool2dMaxFunctor {
public:
......@@ -29,7 +29,7 @@ class Unpool2dMaxFunctor {
framework::Tensor * output);
};
template <typename Place, class T>
template <typename Place, class T, typename T2>
class Unpool2dMaxGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
......
......@@ -66,7 +66,15 @@ int OutputSize(int input_size, int ksize, int padding, int stride) {
}
class UnpoolOp : public framework::OperatorWithKernel {
public:
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of UnpoolOp"
......@@ -102,6 +110,14 @@ class UnpoolOp : public framework::OperatorWithKernel {
};
class UnpoolOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
......@@ -118,9 +134,9 @@ namespace ops = paddle::operators;
REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad,
ops::UnpoolOpGrad);
REGISTER_OP_CPU_KERNEL(unpool,
ops::UnpoolKernel<paddle::platform::CPUPlace, float>,
ops::UnpoolKernel<paddle::platform::CPUPlace, double>);
ops::UnpoolKernel<paddle::platform::CPUPlace, float, int>,
ops::UnpoolKernel<paddle::platform::CPUPlace, double, int>);
REGISTER_OP_CPU_KERNEL(unpool_grad,
ops::UnpoolGradKernel<paddle::platform::CPUPlace, float>,
ops::UnpoolGradKernel<paddle::platform::CPUPlace, double>);
ops::UnpoolGradKernel<paddle::platform::CPUPlace, float, int>,
ops::UnpoolGradKernel<paddle::platform::CPUPlace, double, int>);
......@@ -16,10 +16,10 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(unpool,
ops::UnpoolKernel<paddle::platform::GPUPlace, float>,
ops::UnpoolKernel<paddle::platform::GPUPlace, double>);
ops::UnpoolKernel<paddle::platform::GPUPlace, float, int>,
ops::UnpoolKernel<paddle::platform::GPUPlace, double, int>);
REGISTER_OP_GPU_KERNEL(unpool_grad,
ops::UnpoolGradKernel<paddle::platform::GPUPlace,
float>,
float, int>,
ops::UnpoolGradKernel<paddle::platform::GPUPlace,
double>);
double, int>);
......@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename Place, typename T>
template <typename Place, typename T, typename T2>
class UnpoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -37,12 +37,12 @@ class UnpoolKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> set_zero;
set_zero(context.device_context(), out, static_cast<T>(0));
}
math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
math::Unpool2dMaxFunctor<Place, T, T2> unpool2d_max_forward;
unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
}
};
template <typename Place, typename T>
template <typename Place, typename T, typename T2>
class UnpoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -64,7 +64,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0));
}
math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
math::Unpool2dMaxGradFunctor<Place, T, T2> unpool2d_max_backward;
unpool2d_max_backward(context.device_context(), *in_x, *in_y,
*out, *out_grad, in_x_grad);
}
......
......@@ -53,7 +53,7 @@ class TestUnpoolOp(OpTest):
output = self.Unpool2d_forward_naive(input, indices, self.ksize, \
self.strides, self.paddings).astype("float32")
self.inputs = {'X': input.astype('float32'),
'Y': indices.astype('int16')}
'Y': indices.astype('int32')}
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册