提交 1d9fd1c0 编写于 作者: Y Yang Yang

pass test_recognize_digits

上级 9d26f1a3
......@@ -60,8 +60,9 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
"Due to the settings of paddings, filter_dims and "
"dilations, the output size is less than 0, please check "
"again.");
output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i], strides[i]));
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i],
strides[i]));
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output");
......
......@@ -28,7 +28,7 @@ using Tensor = framework::Tensor;
// Base convolution operator definations for other conv
// like operators to reuse the implementation.
inline int OutputSize(int input_size, int filter_size, int dilation,
inline int ConvOutputSize(int input_size, int filter_size, int dilation,
int padding, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
const int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
......
......@@ -256,6 +256,10 @@ class ParallelDoGradOp : public framework::OperatorBase {
}
}
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
if (s == "@EMPTY@") {
continue;
}
VLOG(3) << "Moving " << s;
CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s));
}
WaitOnPlaces(places);
......@@ -266,6 +270,9 @@ class ParallelDoGradOp : public framework::OperatorBase {
const std::vector<framework::Scope *> &sub_scopes,
const platform::PlaceList &places) const {
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
if (s == "@EMPTY@") {
continue;
}
VLOG(3) << "Accumulating " << s;
if (s == framework::kEmptyVarName) continue;
std::string tmp_name;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册