提交 c218961a 编写于 作者: S sweetsky0901

modify for code review by qingqing

上级 ee4a5d21
...@@ -60,9 +60,9 @@ public: ...@@ -60,9 +60,9 @@ public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
const framework::Tensor& indices, const framework::Tensor& indices,
framework::Tensor * input_grad,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad) { const framework::Tensor& output_grad,
framework::Tensor * input_grad) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
const int input_width = input.dims()[3]; const int input_width = input.dims()[3];
......
...@@ -114,9 +114,9 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T> { ...@@ -114,9 +114,9 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
const framework::Tensor& indices, const framework::Tensor& indices,
framework::Tensor * input_grad,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad) { const framework::Tensor& output_grad,
framework::Tensor * input_grad) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
const int input_width = input.dims()[3]; const int input_width = input.dims()[3];
......
...@@ -14,8 +14,6 @@ limitations under the License. */ ...@@ -14,8 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -37,9 +35,9 @@ class Unpool2dMaxGradFunctor { ...@@ -37,9 +35,9 @@ class Unpool2dMaxGradFunctor {
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
const framework::Tensor& indices, const framework::Tensor& indices,
framework::Tensor * input_grad,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad); const framework::Tensor& output_grad,
framework::Tensor * input_grad);
}; };
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -78,7 +78,7 @@ class UnpoolOp : public framework::OperatorWithKernel { ...@@ -78,7 +78,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
auto in_x_dims = ctx->GetInputDim("X"); auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Y"); auto in_y_dims = ctx->GetInputDim("Y");
std::string unpoolingtype = std::string unpooling_type =
ctx->Attrs().Get<std::string>("unpoolingtype"); ctx->Attrs().Get<std::string>("unpoolingtype");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize"); std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册