提交 97dc451f 编写于 作者: Y Yang Yang

clean up

上级 fccbc2fc
...@@ -182,8 +182,6 @@ static const Tensor* GetTensorFromVar(const Variable* var) { ...@@ -182,8 +182,6 @@ static const Tensor* GetTensorFromVar(const Variable* var) {
const Tensor* t = nullptr; const Tensor* t = nullptr;
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
t = &(var->Get<LoDTensor>()); t = &(var->Get<LoDTensor>());
} else if (var->IsType<Tensor>()) {
t = &(var->Get<Tensor>());
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value()); t = &(var->Get<SelectedRows>().value());
} else { } else {
...@@ -197,8 +195,6 @@ static Tensor* GetMutableTensorFromVar(Variable* var) { ...@@ -197,8 +195,6 @@ static Tensor* GetMutableTensorFromVar(Variable* var) {
Tensor* t = nullptr; Tensor* t = nullptr;
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
t = var->GetMutable<LoDTensor>(); t = var->GetMutable<LoDTensor>();
} else if (var->IsType<Tensor>()) {
t = var->GetMutable<Tensor>();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
t = var->GetMutable<SelectedRows>()->mutable_value(); t = var->GetMutable<SelectedRows>()->mutable_value();
} else { } else {
...@@ -362,8 +358,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -362,8 +358,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
Variable* var = scope_.FindVar(name); Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims(); return var->Get<LoDTensor>().dims();
} else if (var->IsType<Tensor>()) {
return var->Get<Tensor>().dims();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims(); return var->Get<SelectedRows>().GetCompleteDims();
} else { } else {
...@@ -376,8 +370,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -376,8 +370,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
Variable* var = scope_.FindVar(name); Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim); var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<Tensor>()) {
var->GetMutable<Tensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]); var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else { } else {
......
...@@ -34,8 +34,6 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -34,8 +34,6 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto x_dim = ctx->GetInputDim("X"); auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y"); auto y_dim = ctx->GetInputDim("Y");
LOG(INFO) << x_dim;
LOG(INFO) << y_dim;
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input."); "Rank of first input must >= rank of second input.");
ctx->SetOutputDim("Out", x_dim); ctx->SetOutputDim("Out", x_dim);
...@@ -119,9 +117,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -119,9 +117,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
LOG(INFO) << x_dims;
LOG(INFO) << y_dims;
LOG(INFO) << out_dims;
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input."); "Rank of first input must >= rank of second input.");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册