提交 8e55736a 编写于 作者: Z zchen0211

deconv2d

上级 7eeaae16
......@@ -31,12 +31,14 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
for (int i = 0; i < paddings.size(); ++i) {
for (size_t i = 0; i < paddings.size(); ++i) {
PADDLE_ENFORCE_EQ(paddings[i], 0, "No Padding allowed in deconv op.");
}
PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Deconv2DOp input should be 4-D.");
PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Deconv2DOp filter should be 4-D.");
PADDLE_ENFORCE_EQ(in_dims.size(), 4,
"Deconv2DOp input should be 4-D tensor.");
PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
"Deconv2DOp filter should be 4-D tensor.");
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
"input and kernel input dimension should be equal.");
......@@ -52,14 +54,14 @@ Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto,
AddInput(
"Input",
"The input tensor of deconvolution operator. "
"The format of input tensor is NMHW. Where N is batch size, M is the "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of input channels, H and W is the height and width of image.");
AddInput("Filter",
"The filter tensor of deconvolution operator."
"The format of the filter tensor is MCHW, where M is the number of "
"input image channels, C is the number of output image channels, "
"The format of the filter tensor is MCHW, where C is the number of "
"output image channels, M is the number of input image channels, "
"H and W is height and width of filter. "
"We enforce groups number == 1 and padding == 0 in our "
"We enforce groups number == 1 and padding == 0 in "
"deconvolution Scenario.");
AddOutput("Output",
"The output tensor of deconvolution operator."
......
......@@ -55,7 +55,7 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
// filter will be reshaped, so we do not use constant pointer here
// The filter will be reshaped, so it should not be constant pointer
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
......@@ -132,8 +132,8 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
// For filter, we do not use const pointer b/c we will do reshape
// but we should avoid modifying its value
// For filter, we do not use const pointer b/c we will do reshape,
// but we should avoid modifying its value.
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* input_grad =
......@@ -142,7 +142,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
context.Output<Tensor>(framework::GradVarName("Filter"));
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// Actually, no paddings and groups allowed in deconv
// Actually, no paddings and groups allowed in deconv.
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int N = input->dims()[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册