提交 92baa885 编写于 作者: W wanghaoshuang

Fix code style

上级 e82f1008
...@@ -23,7 +23,6 @@ class BlockExpandOp : public framework::OperatorWithKernel { ...@@ -23,7 +23,6 @@ class BlockExpandOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
using namespace framework;
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input of BlockExpandOp should not be null."); "Input of BlockExpandOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
...@@ -142,7 +141,6 @@ class BlockExpandGradOp : public framework::OperatorWithKernel { ...@@ -142,7 +141,6 @@ class BlockExpandGradOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
using namespace framework;
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null."); "Input(Out@GRAD) shouldn't be null.");
......
...@@ -23,6 +23,9 @@ ...@@ -23,6 +23,9 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
inline int get_output_size(int img_size, int block_size, int stride, inline int get_output_size(int img_size, int block_size, int stride,
int padding) { int padding) {
return (1 + (img_size + 2 * padding - block_size + stride - 1) / stride); return (1 + (img_size + 2 * padding - block_size + stride - 1) / stride);
...@@ -32,7 +35,6 @@ template <typename Place, typename T> ...@@ -32,7 +35,6 @@ template <typename Place, typename T>
class BlockExpandKernel : public framework::OpKernel<T> { class BlockExpandKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using namespace framework;
const Tensor* in = ctx.Input<Tensor>("X"); const Tensor* in = ctx.Input<Tensor>("X");
LoDTensor* out = ctx.Output<LoDTensor>("Out"); LoDTensor* out = ctx.Output<LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
...@@ -89,11 +91,10 @@ template <typename Place, typename T> ...@@ -89,11 +91,10 @@ template <typename Place, typename T>
class BlockExpandGradKernel : public framework::OpKernel<T> { class BlockExpandGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using namespace framework;
auto* in = ctx.Input<Tensor>("X"); auto* in = ctx.Input<Tensor>("X");
Tensor* d_out = Tensor* d_out =
const_cast<Tensor*>(ctx.Input<Tensor>(framework::GradVarName("Out"))); const_cast<Tensor*>(ctx.Input<Tensor>(framework::GradVarName("Out")));
auto* d_x = ctx.Output<Tensor>(GradVarName("X")); auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
auto x_v = framework::EigenVector<T>::Flatten(*d_x); auto x_v = framework::EigenVector<T>::Flatten(*d_x);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册