提交 50c5e9b0 编写于 作者: S Sylwester Fraczek

reshape_2d used from ddim.h

test=develop
上级 55d6950a
...@@ -44,18 +44,6 @@ namespace ir { ...@@ -44,18 +44,6 @@ namespace ir {
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name); \ GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name) GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)
// reshape to two dimensions {A, B * C * ...}
DDim make_dims_2d(DDim dims) {
auto dims_count = dims.size();
PADDLE_ENFORCE_GT(dims_count, 0);
int size2 = 1;
for (int i = 1; i < dims_count; i++) {
size2 *= dims[i];
}
return make_ddim({dims[0], size2});
}
void recompute_bias_and_weights(const Scope* scope, void recompute_bias_and_weights(const Scope* scope,
ir::Node* conv_weight, // ir::Node* conv_weight, //
const ir::Node& bn_scale, // const ir::Node& bn_scale, //
...@@ -104,7 +92,7 @@ void recompute_bias_and_weights(const Scope* scope, ...@@ -104,7 +92,7 @@ void recompute_bias_and_weights(const Scope* scope,
// Re-compute weight of conv2d from BN // Re-compute weight of conv2d from BN
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>(); auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto weights_shape = weights->dims(); auto weights_shape = weights->dims();
auto weights_shape_2d = make_dims_2d(weights_shape); auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
EigenMatrixArrayMap weights_array_2d( EigenMatrixArrayMap weights_array_2d(
weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0], weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册