提交 1d9fd1c0 编写于 作者: Y Yang Yang

pass test_recognize_digits

上级 9d26f1a3
...@@ -60,8 +60,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -60,8 +60,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
"Due to the settings of paddings, filter_dims and " "Due to the settings of paddings, filter_dims and "
"dilations, the output size is less than 0, please check " "dilations, the output size is less than 0, please check "
"again."); "again.");
output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i], strides[i])); dilations[i], paddings[i],
strides[i]));
} }
ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output"); ctx->ShareLoD("Input", "Output");
......
...@@ -28,7 +28,7 @@ using Tensor = framework::Tensor; ...@@ -28,7 +28,7 @@ using Tensor = framework::Tensor;
// Base convolution operator definitions for other conv // like operators to reuse the implementation.
// like operators to reuse the implementation. // like operators to reuse the implementation.
inline int OutputSize(int input_size, int filter_size, int dilation, inline int ConvOutputSize(int input_size, int filter_size, int dilation,
int padding, int stride) { int padding, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1; const int dkernel = dilation * (filter_size - 1) + 1;
const int output_size = (input_size + 2 * padding - dkernel) / stride + 1; const int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
......
...@@ -256,6 +256,10 @@ class ParallelDoGradOp : public framework::OperatorBase { ...@@ -256,6 +256,10 @@ class ParallelDoGradOp : public framework::OperatorBase {
} }
} }
for (auto &s : Outputs(framework::GradVarName(kParameters))) { for (auto &s : Outputs(framework::GradVarName(kParameters))) {
if (s == "@EMPTY@") {
continue;
}
VLOG(3) << "Moving " << s;
CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s)); CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s));
} }
WaitOnPlaces(places); WaitOnPlaces(places);
...@@ -266,6 +270,9 @@ class ParallelDoGradOp : public framework::OperatorBase { ...@@ -266,6 +270,9 @@ class ParallelDoGradOp : public framework::OperatorBase {
const std::vector<framework::Scope *> &sub_scopes, const std::vector<framework::Scope *> &sub_scopes,
const platform::PlaceList &places) const { const platform::PlaceList &places) const {
for (auto &s : Outputs(framework::GradVarName(kParameters))) { for (auto &s : Outputs(framework::GradVarName(kParameters))) {
if (s == "@EMPTY@") {
continue;
}
VLOG(3) << "Accumulating " << s; VLOG(3) << "Accumulating " << s;
if (s == framework::kEmptyVarName) continue; if (s == framework::kEmptyVarName) continue;
std::string tmp_name; std::string tmp_name;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册