提交 c218961a 编写于 作者: S sweetsky0901

modify for code review by qingqing

上级 ee4a5d21
......@@ -60,9 +60,9 @@ public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad) {
const framework::Tensor& output_grad,
framework::Tensor * input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......
......@@ -114,9 +114,9 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad) {
const framework::Tensor& output_grad,
framework::Tensor * input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......
......@@ -14,8 +14,6 @@ limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
......@@ -37,9 +35,9 @@ class Unpool2dMaxGradFunctor {
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
const framework::Tensor& indices,
framework::Tensor * input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad);
const framework::Tensor& output_grad,
framework::Tensor * input_grad);
};
} // namespace math
} // namespace operators
......
......@@ -78,7 +78,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Y");
std::string unpoolingtype =
std::string unpooling_type =
ctx->Attrs().Get<std::string>("unpoolingtype");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册