未验证 提交 3bd1d22a 编写于 作者: C chengduo 提交者: GitHub

Enhance fused_elementwise_activation_op (#12837)

* Enhance the function of fused_elementwise_activation_op

* enhance unit test

* Clean Code And Add Doc

* Add compound functors

* Fix doc and enhance unit test

* define Dx and Dy for d_binary_func

* add mul_scale

* add mul_scale

* add elementwise_mul

* code refine

* code refine

* add doc

* add  AsIntermediate
上级 a615ad46
...@@ -12,14 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,14 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fused_elemwise_activation_op.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/operators/fused_elemwise_activation_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
/*
* Whether the compound function is Unary(Binary(X, Y)).
* For Unary(Binary(X, Y)), the intermediate_out's shape is the same the final
* out.
*/
static bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
static std::unordered_set<std::string> binary_fun = {
"elementwise_add", "elementwise_mul", "elementwise_add_grad",
"elementwise_mul_grad"};
return binary_fun.count(functor_list[1]) != 0;
}
/*
* Whether the Input(X) could be absent.
*/
static bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
static std::unordered_set<std::string> binary_fun = {"elementwise_add_grad"};
return binary_fun.count(functor_list[0]) != 0 ||
binary_fun.count(functor_list[1]) != 0;
}
/*
* Whether the compound function is supported.
* For Unary(Binary(X, Y)), the intermediate_out's shape is the same the final
* out.
*/
static bool IsSupportedCompound(const std::vector<std::string> &functors) {
static std::unordered_set<std::string> unary_fun = {"scale", "relu"};
static std::unordered_set<std::string> binary_fun = {"elementwise_add",
"elementwise_mul"};
std::string unary_fun_str;
if (binary_fun.count(functors[0])) {
unary_fun_str = functors[1];
} else if (binary_fun.count(functors[1])) {
unary_fun_str = functors[0];
} else {
PADDLE_THROW("%s and %s are not included in fused_list.", functors[0],
functors[1]);
}
PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
"%s is not included in fused_list.", unary_fun_str);
return true;
}
class FusedElemwiseActivationOp : public framework::OperatorWithKernel { class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -37,11 +83,44 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { ...@@ -37,11 +83,44 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
auto x_dim = ctx->GetInputDim("X"); auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y"); auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.");
ctx->SetOutputDim("Out", x_dim); // Whether the shape of Y is a continuous subsequence of X,
ctx->ShareLoD("X", /*->*/ "Out"); // For more information please refer to the op's introduction.
bool bcast_y = x_dim.size() >= y_dim.size();
if (x_dim.size() == y_dim.size()) {
for (int i = 0; i < x_dim.size(); ++i) {
if (x_dim[i] < y_dim[i]) {
bcast_y = false;
break;
}
}
}
auto &out_dim = bcast_y ? x_dim : y_dim;
std::string out_lod = bcast_y ? "X" : "Y";
if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
PADDLE_ENFORCE(ctx->HasOutput("IntermediateOut"),
"Output(IntermediateOut) of FusedElemwiseActivationOp "
"should not be null.");
if (IsUnaryCompound(
ctx->Attrs().Get<std::vector<std::string>>("functor_list"))) {
// for Unary(Binary(X, Y)), the shape and lod of out and
// intermediate_out are the same.
ctx->SetOutputDim("IntermediateOut", out_dim);
// set the lod of intermediate_out
ctx->ShareLoD(out_lod, /*->*/ "IntermediateOut");
} else {
// for Binary(X, Unary(Y)), the shape and lod of Y and
// intermediate_out are the same.
ctx->SetOutputDim("IntermediateOut", y_dim);
// set the lod of intermediate_out
ctx->ShareLoD("Y", /*->*/ "IntermediateOut");
}
}
ctx->SetOutputDim("Out", out_dim);
ctx->ShareLoD(out_lod, /*->*/ "Out");
} }
protected: protected:
...@@ -59,29 +138,42 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { ...@@ -59,29 +138,42 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker { class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(vector<Tensor>)"); AddInput(
AddInput("Y", "(vector<Tensor>)"); "X",
AddOutput("Out", "vector<Tensor>"); "(Tensor) The input tensor of fused_elemwise_activation operator.");
AddInput(
"Y",
"(Tensor) The input tensor of fused_elemwise_activation operator.");
AddOutput("Out",
"vector<Tensor> The output tensor of fused_elemwise_activation "
"operator.");
AddOutput("IntermediateOut",
"Tensor The IntermediateOut tensor of fused_elemwise_activation "
"operator.")
.AsIntermediate();
AddAttr<int>("axis", AddAttr<int>("axis",
"axis is used by elementwise_op, the default value is -1.") "axis is used by elementwise_op, the default value is -1.")
.SetDefault(-1); .SetDefault(-1);
AddAttr<float>("scale", AddAttr<float>("scale",
"scale is used by scale_op, the default value is 0.0.") "scale is used by scale_op, the default value is 0.0.")
.SetDefault(0.0); .SetDefault(0.0);
AddAttr<bool>("recomputation", AddAttr<bool>(
"Whether to recompute the Out." "recomputation",
"fused_elemwise_activation_grad has two methods to get the " "Whether to recompute the Out."
"dx and dy, one " "The computation of fused_elemwise_activation_grad has two methods to "
"is to use the 'Out', and the other is not to use it. " "get the dx and dy, one is to use the 'Out', and the other is not. "
"The former method will save the time of recomputing the " "The former method will save the time of recomputing the 'Out', but it "
"'Out', but it must occupy the memory to store the 'out'. " "must occupy the memory to store the 'out'. While, the later method "
"While, the later method can avoid occupying the memory, " "can avoid occupying the memory, but it must recompute the 'Out'. "
"but it must recompute the 'Out'. The default value is true.") "It is useful for Unary(Binary(X, Y)). The default value is true.")
.SetDefault(true); .SetDefault(true);
AddAttr<bool>("keep_intermediate_value",
"Whether to save the intermediate_out.")
.SetDefault(false);
AddAttr<std::vector<std::string>>("functor_list", AddAttr<std::vector<std::string>>("functor_list",
"The functors that should be fused.") "The functors that should be fused.")
.AddCustomChecker([&](const std::vector<std::string> &functor_list) { .AddCustomChecker([&](const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE(ValidCheck(functor_list)); PADDLE_ENFORCE(IsSupportedCompound(functor_list));
}); });
AddComment(R"DOC( AddComment(R"DOC(
...@@ -93,30 +185,38 @@ operators (elementwise_op and activation_op): ...@@ -93,30 +185,38 @@ operators (elementwise_op and activation_op):
Z = Binary(X, Unary(Y)) Z = Binary(X, Unary(Y))
Z = Unary(Binary(X, Y)) Z = Unary(Binary(X, Y))
The attributions of activation_op can be get from fused_elemwise_activation_op's There are two cases for this operator:
attributions. functor_list records the functors to be fused, for example
"scale,elementwise_add".
)DOC"); 1. The shape of $Y$ and $X$ is the same.
} 2. The shape of $Y$ is a continuous subsequence of $X$ or the shape of $X$ is a continuous subsequence of $Y$.
private: For case 2 (assume that the shape of $Y$ is a continuous subsequence of $X$ ):
bool ValidCheck(const std::vector<std::string> &functors) {
std::unordered_set<std::string> unary_fun = {"scale", "relu"};
std::unordered_set<std::string> binary_fun = {"elementwise_add"};
std::string unary_fun_str; 1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
if (binary_fun.count(functors[0])) { for broadcasting $Y$ onto $X$.
unary_fun_str = functors[1]; 2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
} else if (binary_fun.count(functors[1])) { 3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
unary_fun_str = functors[0]; subsequence, such as shape(Y) = (2, 1) => (2).
} else {
PADDLE_THROW("%s and %s are not included in fused_list.", functors[0], For example:
functors[1]);
} .. code-block:: python
PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
"%s is not included in fused_list.", unary_fun_str); shape(X) = (2, 3, 4, 5), shape(Y) = (,)
return true; shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2
shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
The inputs $X$ and $Y$ can carry the different LoD information.
But the output only shares the LoD information with the one whose shape is the same with Out.
The attributions of activation_op can be get from fused_elemwise_activation_op's.
The functor_list records the functions to be fused, for example
["scale", "elementwise_add"].
)DOC");
} }
}; };
...@@ -141,6 +241,7 @@ class FusedElemwiseActivationGradMaker ...@@ -141,6 +241,7 @@ class FusedElemwiseActivationGradMaker
op_desc_ptr->SetInput(framework::GradVarName(output_param), op_desc_ptr->SetInput(framework::GradVarName(output_param),
this->OutputGrad(output_param)); this->OutputGrad(output_param));
} }
op_desc_ptr->SetAttrMap(this->Attrs()); op_desc_ptr->SetAttrMap(this->Attrs());
std::vector<std::string> functor_names = std::vector<std::string> functor_names =
...@@ -158,40 +259,59 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { ...@@ -158,40 +259,59 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@Grad) should not be null");
if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE(ctx->HasInput("IntermediateOut"),
auto y_dims = ctx->GetInputDim("Y"); "Input(IntermediateOut) should not be null");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); } else {
PADDLE_ENFORCE_EQ(ctx->Inputs(framework::GradVarName("Out")).size(), 1);
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), }
"Rank of first input must >= rank of second input.");
auto funtor_list =
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y"); auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims); if (ctx->HasInputs("X")) {
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", x_grad_name);
} else {
// Node: If "X" is absence, the shape of Y should be a continuous
// subsequence of X, if not, we could not infer the shape of dx.
// Currently, only when Binary is elementwise_add or elementwise_sub,
// the "X" could be absent.
PADDLE_ENFORCE(InputXCanBeAbsent(funtor_list),
"Only when BinaryFunctor is elementwise_add, the 'X' "
"could be absent.");
// For Unary(Binary(X, Y)), IntermediateOut should not be empty.
if (IsUnaryCompound(funtor_list)) {
PADDLE_ENFORCE(
ctx->HasInputs("IntermediateOut"),
"If the compound_functor is Unary(Binary(X, Y)) and Binary "
"is elementwise_add, the intermediate_out must be not absent.");
}
ctx->SetOutputDim(x_grad_name,
ctx->GetInputDim(framework::GradVarName("Out")));
ctx->ShareLoD(framework::GradVarName("Out"), x_grad_name);
}
} }
if (ctx->HasOutput(y_grad_name)) { if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
ctx->ShareLoD("Y", y_grad_name);
} }
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type_index = ctx.Input<framework::Tensor>("X")->type(); // PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_EQ(input_data_type_index, auto input_data_type_index = ctx.Input<framework::Tensor>("Y")->type();
ctx.Input<framework::Tensor>("Y")->type(),
"The element's type of input should be the same.");
PADDLE_ENFORCE_EQ(
input_data_type_index,
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
"The element's type of input should be the same.");
auto input_data_type = framework::ToDataType(input_data_type_index); auto input_data_type = framework::ToDataType(input_data_type_index);
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace operators {
namespace math {
template <typename T, typename BinaryFunctor, typename UnaryFunctor>
struct BinaryCompoundFunctor {
BinaryCompoundFunctor(const BinaryFunctor func1, const UnaryFunctor func2)
: func1_(func1), func2_(func2) {}
// Z = BinaryFunctor(X, UnaryFunctor(Y))
inline HOSTDEVICE T GetOut(T x, T y) { return func1_(x, func2_(y)); }
inline HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediat_out) {
return func1_(x, intermediat_out);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return func2_(y); }
BinaryFunctor func1_;
UnaryFunctor func2_;
};
template <typename T, typename UnaryFunctor, typename BinaryFunctor>
struct UnaryCompoundFunctor {
UnaryCompoundFunctor(const UnaryFunctor func1, const BinaryFunctor func2)
: func1_(func1), func2_(func2) {}
// Z = UnaryFunctor(BinaryFunctor(X, Y))
inline HOSTDEVICE T GetOut(T x, T y) { return func1_(func2_(x, y)); }
inline HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediat_out) {
return func1_(intermediat_out);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return func2_(x, y); }
UnaryFunctor func1_;
BinaryFunctor func2_;
};
// FIXME(zcd): DBinaryFun and DUnaryFun have to method to get
// the dx, one is to use the 'out', and the other is not to use it.
// the former method will save the time of recomputing the
// 'out', but it must occupy the memory to store the 'out'.
// While the later method can avoid occupying this memory,
// but it must recompute the 'out'.
template <typename T, typename DBinaryFun, typename UnaryFun>
struct BinaryCompoundGradDxFunctor {
BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun)
: d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dx(x, unary_fun_(y));
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
return dout * d_binary_fun_.Dx(x, intermediate_out);
}
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
};
template <typename T, typename DBinaryFun, typename UnaryFun,
typename DUnaryFun>
struct BinaryCompoundGradDyFunctor {
BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun,
const DUnaryFun &d_unary_fun)
: d_binary_fun_(d_binary_fun),
unary_fun_(unary_fun),
d_unary_fun_(d_unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_(y);
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
return dout * d_binary_fun_.Dy(x, intermediate_out) *
d_unary_fun_(y, intermediate_out);
}
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
DUnaryFun d_unary_fun_;
};
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool Recomputation = true>
struct UnaryCompoundGradDxFunctor {
UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
const DBinaryFun &d_binary_fun)
: d_unary_fun_(d_unary_fun),
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(binary_fun_(x, y));
} else {
base = dout * d_unary_fun_(binary_fun_(x, y), out);
}
return base * d_binary_fun_.Dx(x, y);
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(intermediate_out);
} else {
base = dout * d_unary_fun_(intermediate_out, out);
}
return base * d_binary_fun_.Dx(x, y);
}
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
DBinaryFun d_binary_fun_;
};
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool Recomputation = true>
struct UnaryCompoundGradDyFunctor {
UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
const DBinaryFun &d_binary_fun)
: d_unary_fun_(d_unary_fun),
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(binary_fun_(x, y));
} else {
base = dout * d_unary_fun_(binary_fun_(x, y), out);
}
return base * d_binary_fun_.Dy(x, y);
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(intermediate_out);
} else {
base = dout * d_unary_fun_(intermediate_out, out);
}
return base * d_binary_fun_.Dy(x, y);
}
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
DBinaryFun d_binary_fun_;
};
} // namespace math
} // namespace operators
} // namespace paddle
...@@ -18,6 +18,19 @@ namespace paddle { ...@@ -18,6 +18,19 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
// MulFunctor
template <typename T>
struct MulFunctor {
// out = x * y;
inline HOSTDEVICE T operator()(T x, T y) { return x * y; }
};
template <typename T>
struct MulGradFunctor {
inline HOSTDEVICE T Dx(T x, T y) { return y; }
inline HOSTDEVICE T Dy(T x, T y) { return x; }
};
// AddFunctor // AddFunctor
template <typename T> template <typename T>
struct AddFunctor { struct AddFunctor {
...@@ -27,9 +40,8 @@ struct AddFunctor { ...@@ -27,9 +40,8 @@ struct AddFunctor {
template <typename T> template <typename T>
struct AddGradFunctor { struct AddGradFunctor {
inline HOSTDEVICE T operator()(T x, T y) { return 1; } inline HOSTDEVICE T Dx(T x, T y) { return 1; }
inline HOSTDEVICE T Dy(T x, T y) { return 1; }
inline HOSTDEVICE T operator()(T x, T y, T out) const { return 1; }
}; };
template <typename T> template <typename T>
......
...@@ -47,7 +47,8 @@ def get_numeric_gradient(place, ...@@ -47,7 +47,8 @@ def get_numeric_gradient(place,
input_to_check, input_to_check,
output_names, output_names,
delta=0.005, delta=0.005,
in_place=False): in_place=False,
sum_outputs=None):
# FIXME: change this method by compile time concepts # FIXME: change this method by compile time concepts
set_input(scope, op, inputs, place) set_input(scope, op, inputs, place)
...@@ -58,9 +59,11 @@ def get_numeric_gradient(place, ...@@ -58,9 +59,11 @@ def get_numeric_gradient(place,
sum = [] sum = []
op.run(scope, place) op.run(scope, place)
for output_name in output_names: for output_name in output_names:
if sum_outputs and output_name not in sum_outputs:
continue
sum.append( sum.append(
np.array(scope.find_var(output_name).get_tensor()).mean()) np.array(scope.find_var(output_name).get_tensor()).mean())
return np.array(sum).mean() return np.array(sum).sum() / len(output_names)
tensor_to_check = scope.find_var(input_to_check).get_tensor() tensor_to_check = scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.shape()) tensor_size = product(tensor_to_check.shape())
...@@ -396,13 +399,14 @@ class OpTest(unittest.TestCase): ...@@ -396,13 +399,14 @@ class OpTest(unittest.TestCase):
numeric_grad_delta=0.005, numeric_grad_delta=0.005,
in_place=False, in_place=False,
max_relative_error=0.005, max_relative_error=0.005,
user_defined_grads=None): user_defined_grads=None,
sum_outputs=None):
places = self._get_places() places = self._get_places()
for place in places: for place in places:
self.check_grad_with_place(place, inputs_to_check, output_names, self.check_grad_with_place(place, inputs_to_check, output_names,
no_grad_set, numeric_grad_delta, no_grad_set, numeric_grad_delta,
in_place, max_relative_error, in_place, max_relative_error,
user_defined_grads) user_defined_grads, sum_outputs)
def check_grad_with_place(self, def check_grad_with_place(self,
place, place,
...@@ -412,7 +416,8 @@ class OpTest(unittest.TestCase): ...@@ -412,7 +416,8 @@ class OpTest(unittest.TestCase):
numeric_grad_delta=0.005, numeric_grad_delta=0.005,
in_place=False, in_place=False,
max_relative_error=0.005, max_relative_error=0.005,
user_defined_grads=None): user_defined_grads=None,
sum_outputs=None):
self.scope = core.Scope() self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict() op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict() op_outputs = self.outputs if hasattr(self, "outputs") else dict()
...@@ -435,7 +440,8 @@ class OpTest(unittest.TestCase): ...@@ -435,7 +440,8 @@ class OpTest(unittest.TestCase):
input_to_check, input_to_check,
output_names, output_names,
delta=numeric_grad_delta, delta=numeric_grad_delta,
in_place=in_place) for input_to_check in inputs_to_check in_place=in_place,
sum_outputs=sum_outputs) for input_to_check in inputs_to_check
] ]
analytic_grads = self._get_gradient(inputs_to_check, place, analytic_grads = self._get_gradient(inputs_to_check, place,
output_names, no_grad_set) output_names, no_grad_set)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册