提交 8e55736a 编写于 作者: Z zchen0211

deconv2d

上级 7eeaae16
...@@ -31,12 +31,14 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -31,12 +31,14 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
for (int i = 0; i < paddings.size(); ++i) { for (size_t i = 0; i < paddings.size(); ++i) {
PADDLE_ENFORCE_EQ(paddings[i], 0, "No Padding allowed in deconv op."); PADDLE_ENFORCE_EQ(paddings[i], 0, "No Padding allowed in deconv op.");
} }
PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Deconv2DOp input should be 4-D."); PADDLE_ENFORCE_EQ(in_dims.size(), 4,
PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Deconv2DOp filter should be 4-D."); "Deconv2DOp input should be 4-D tensor.");
PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
"Deconv2DOp filter should be 4-D tensor.");
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
"input and kernel input dimension should be equal."); "input and kernel input dimension should be equal.");
...@@ -52,14 +54,14 @@ Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto, ...@@ -52,14 +54,14 @@ Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto,
AddInput( AddInput(
"Input", "Input",
"The input tensor of deconvolution operator. " "The input tensor of deconvolution operator. "
"The format of input tensor is NMHW. Where N is batch size, M is the " "The format of input tensor is NCHW. Where N is batch size, C is the "
"number of input channels, H and W is the height and width of image."); "number of input channels, H and W is the height and width of image.");
AddInput("Filter", AddInput("Filter",
"The filter tensor of deconvolution operator." "The filter tensor of deconvolution operator."
"The format of the filter tensor is MCHW, where M is the number of " "The format of the filter tensor is MCHW, where C is the number of "
"input image channels, C is the number of output image channels, " "output image channels, M is the number of input image channels, "
"H and W is height and width of filter. " "H and W is height and width of filter. "
"We enforce groups number == 1 and padding == 0 in our " "We enforce groups number == 1 and padding == 0 in "
"deconvolution Scenario."); "deconvolution Scenario.");
AddOutput("Output", AddOutput("Output",
"The output tensor of deconvolution operator." "The output tensor of deconvolution operator."
......
...@@ -55,7 +55,7 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> { ...@@ -55,7 +55,7 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor* input = context.Input<Tensor>("Input");
// filter will be reshaped, so we do not use constant pointer here // The filter will be reshaped, so it should not be constant pointer
Tensor filter = *context.Input<Tensor>("Filter"); Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output"); Tensor* output = context.Output<Tensor>("Output");
...@@ -132,8 +132,8 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> { ...@@ -132,8 +132,8 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
const Tensor* output_grad = const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output")); context.Input<Tensor>(framework::GradVarName("Output"));
// For filter, we do not use const pointer b/c we will do reshape // For filter, we do not use const pointer b/c we will do reshape,
// but we should avoid modifying its value // but we should avoid modifying its value.
Tensor filter = *context.Input<Tensor>("Filter"); Tensor filter = *context.Input<Tensor>("Filter");
Tensor* input_grad = Tensor* input_grad =
...@@ -142,7 +142,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> { ...@@ -142,7 +142,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
context.Output<Tensor>(framework::GradVarName("Filter")); context.Output<Tensor>(framework::GradVarName("Filter"));
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// Actually, no paddings and groups allowed in deconv // Actually, no paddings and groups allowed in deconv.
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int N = input->dims()[0]; int N = input->dims()[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册