Commit f2a66ffa authored by: F fengjiayi

Follow comments

Parent: 256d6a33
@@ -44,7 +44,7 @@ class LargerThanChecker {
  public:
   explicit LargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
   void operator()(T& value) const {
-    PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail");
+    PADDLE_ENFORCE(value > lower_bound_, "larger_than check fails.");
   }
  private:
@@ -56,7 +56,7 @@ class EqualLargerThanChecker {
  public:
   explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
   void operator()(T& value) const {
-    PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fail");
+    PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fails.");
   }
  private:
...
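These checkers back the `.EqualLargerThan(1)` calls on the mul_op attributes further down; they are presumably invoked when an operator's attributes are validated. A minimal standalone analogue of the contract they enforce (not Paddle code; PADDLE_ENFORCE is replaced by a plain throw):

#include <stdexcept>

template <typename T>
class EqualLargerThanChecker {
 public:
  explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
  void operator()(T& value) const {
    // Analogue of PADDLE_ENFORCE(value >= lower_bound_, ...).
    if (!(value >= lower_bound_)) {
      throw std::runtime_error("equal_larger_than check fails.");
    }
  }

 private:
  T lower_bound_;
};

int main() {
  EqualLargerThanChecker<int> check(1);
  int num_col_dims = 2;
  check(num_col_dims);  // passes: 2 >= 1
  // int bad = 0; check(bad);  // would throw
  return 0;
}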
@@ -284,11 +284,13 @@ DDim::DDim(std::initializer_list<int> init_list) {
   *this = make_ddim(init_list);
 }

-DDim flatten_to_2d(const DDim& src, int num_row_dims) {
+// Reshape a tensor to a matrix. The matrix's first dimension (column length)
+// will be the product of the tensor's first `num_col_dims` dimensions.
+DDim flatten_to_2d(const DDim& src, int num_col_dims) {
   int rank = src.size();
   return make_ddim(
-      {static_cast<int>(product(slice_ddim(src, 0, rank - num_row_dims))),
-       static_cast<int>(product(slice_ddim(src, rank - num_row_dims, rank)))});
+      {static_cast<int>(product(slice_ddim(src, 0, num_col_dims))),
+       static_cast<int>(product(slice_ddim(src, num_col_dims, rank)))});
 }

 DDim flatten_to_1d(const DDim& src) {
...
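This semantic change is the heart of the commit: the split point now counts `num_col_dims` dimensions from the front instead of `num_row_dims` from the back. A standalone sketch of the new shape arithmetic (plain C++, not the Paddle implementation; DDim/product/slice_ddim are replaced with std::vector and std::accumulate):

#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

std::vector<int> flatten_to_2d(const std::vector<int>& src, int num_col_dims) {
  int rank = static_cast<int>(src.size());
  auto prod = [&](int begin, int end) {
    return std::accumulate(src.begin() + begin, src.begin() + end, 1,
                           std::multiplies<int>());
  };
  // First dim: product of the leading num_col_dims extents; second: the rest.
  return {prod(0, num_col_dims), prod(num_col_dims, rank)};
}

int main() {
  // Matches the ReshapeToMatrix test below: (2, 3, 4, 9) with
  // num_col_dims = 2 flattens to (2*3, 4*9) = (6, 36).
  auto mat = flatten_to_2d({2, 3, 4, 9}, 2);
  assert(mat[0] == 6 && mat[1] == 36);
  return 0;
}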
@@ -115,7 +115,7 @@ int arity(const DDim& ddim);

 std::ostream& operator<<(std::ostream&, const DDim&);

-DDim flatten_to_2d(const DDim& src, int num_row_dims);
+DDim flatten_to_2d(const DDim& src, int num_col_dims);

 DDim flatten_to_1d(const DDim& src);
...
@@ -64,21 +64,21 @@ struct EigenTensor {
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
-  static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_row_dims) {
+  static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_col_dims) {
     int rank = tensor.dims_.size();
-    PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank,
-                   "`num_row_dims` must be between (0, rank_of_tensor).");
+    PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+                   "`num_col_dims` must be between (0, rank_of_tensor).");
     return EigenMatrix::From(tensor,
-                             flatten_to_2d(tensor.dims(), num_row_dims));
+                             flatten_to_2d(tensor.dims(), num_col_dims));
   }

   static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
-                                                 int num_row_dims) {
+                                                 int num_col_dims) {
     int rank = tensor.dims_.size();
-    PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank,
-                   "`num_row_dims` must be between (0, rank_of_tensor).");
+    PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+                   "`num_col_dims` must be between (0, rank_of_tensor).");
     return EigenMatrix::From(tensor,
-                             flatten_to_2d(tensor.dims(), num_row_dims));
+                             flatten_to_2d(tensor.dims(), num_col_dims));
   }
 };
...
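Note the bounds check: `num_col_dims` must lie strictly between 0 and the tensor's rank so that both slices passed to flatten_to_2d are non-empty. A rank-4 tensor, for example, accepts num_col_dims of 1, 2, or 3.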
@@ -149,10 +149,10 @@ inline Tensor& Tensor::Resize(const DDim& dims) {

 inline const DDim& Tensor::dims() const { return dims_; }

 template <typename T>
-inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) {
+inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
   Tensor res;
   res.ShareDataWith<T>(src);
-  res.Resize(flatten_to_2d(src.dims(), num_row_dims));
+  res.Resize(flatten_to_2d(src.dims(), num_col_dims));
   return res;
 }
...
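Because the result shares storage with `src` via ShareDataWith, ReshapeToMatrix is a zero-copy view: only the dims metadata changes, as the test below confirms.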
@@ -263,7 +263,7 @@ TEST(Tensor, CopyFrom) {
 #endif
 }

-TEST(Tensor, FlattenToMatrix) {
+TEST(Tensor, ReshapeToMatrix) {
   using namespace paddle::framework;
   using namespace paddle::platform;
   Tensor src;
@@ -271,7 +271,7 @@ TEST(Tensor, FlattenToMatrix) {
   for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
     src_ptr[i] = i;
   }
-  Tensor res = FlattenToMatrix<int>(src, 2);
+  Tensor res = ReshapeToMatrix<int>(src, 2);
   ASSERT_EQ(res.dims()[0], 2 * 3);
   ASSERT_EQ(res.dims()[1], 4 * 9);
 }
\ No newline at end of file
@@ -27,20 +27,20 @@ class MulOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext &ctx) const override {
     auto x_dims = ctx.Input<Tensor>("X")->dims();
     auto y_dims = ctx.Input<Tensor>("Y")->dims();
-    int x_num_row_dims = GetAttr<int>("x_num_row_dims");
-    int y_num_row_dims = GetAttr<int>("y_num_row_dims");
-    PADDLE_ENFORCE(x_dims.size() > x_num_row_dims,
+    int x_num_col_dims = GetAttr<int>("x_num_col_dims");
+    int y_num_col_dims = GetAttr<int>("y_num_col_dims");
+    PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
                    "The rank of input tensor X(%s) should be larger than "
-                   "`mul_op`'s `x_num_row_dims`.",
+                   "`mul_op`'s `x_num_col_dims`.",
                    ctx.op().Input("X"));
-    PADDLE_ENFORCE(y_dims.size() > y_num_row_dims,
+    PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
                    "The rank of input tensor Y(%s) should be larger than "
-                   "`mul_op`'s `y_num_row_dims`.",
+                   "`mul_op`'s `y_num_col_dims`.",
                    ctx.op().Input("Y"));
-    auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_row_dims);
-    auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_row_dims);
+    auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
+    auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);

     PADDLE_ENFORCE_EQ(
         x_mat_dims[1], y_mat_dims[0],
@@ -57,19 +57,19 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Y", "The second input of mul op");
     AddOutput("Out", "The output of mul op");
     AddAttr<int>(
-        "x_num_row_dims",
+        "x_num_col_dims",
         "mul_op can take tensors with more than two dimensions as input `X`, "
-        "in that case, tensors will be flattened to a matrix. The matrix's "
-        "second dimension(row length) will be the product of tensor's last "
-        "`num_row_dims` dimensions, and the matrix's first dimension(column "
-        "length) will be the product of tensor's first `rank - num_row_dims` "
-        "dimensions.")
+        "in that case, tensors will be reshaped to a matrix. The matrix's "
+        "first dimension (column length) will be the product of the tensor's "
+        "first `num_col_dims` dimensions, and the matrix's second dimension "
+        "(row length) will be the product of the tensor's last "
+        "`rank - num_col_dims` dimensions.")
         .SetDefault(1)
         .EqualLargerThan(1);
     AddAttr<int>(
-        "y_num_row_dims",
+        "y_num_col_dims",
         "mul_op can take tensors with more than two dimensions as input `Y`, "
-        "in that case, tensors will be flattened to a matrix. Just like input "
-        "`X`.")
+        "in that case, tensors will be reshaped to a matrix. Just like input "
+        "`X`.")
         .SetDefault(1)
         .EqualLargerThan(1);
@@ -98,9 +98,9 @@ class MulOpGrad : public framework::OperatorWithKernel {
     auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));

     auto x_mat_dims =
-        framework::flatten_to_2d(x_dims, GetAttr<int>("x_num_row_dims"));
+        framework::flatten_to_2d(x_dims, GetAttr<int>("x_num_col_dims"));
     auto y_mat_dims =
-        framework::flatten_to_2d(y_dims, GetAttr<int>("y_num_row_dims"));
+        framework::flatten_to_2d(y_dims, GetAttr<int>("y_num_col_dims"));

     PADDLE_ENFORCE_EQ(
         x_mat_dims[0], out_dims[0],
...
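To make the shape contract concrete, this is the arithmetic behind the updated TestMulOp2 at the bottom of this diff:

// X: (15, 4, 12, 10), x_num_col_dims = 2 -> x_matrix: (15*4, 12*10) = (60, 120)
// Y: (4, 30, 8, 2, 9), y_num_col_dims = 2 -> y_matrix: (4*30, 8*2*9) = (120, 144)
// x_mat_dims[1] == y_mat_dims[0] == 120, so the PADDLE_ENFORCE_EQ passes,
// and the expected Out computed by np.dot has shape (60, 144).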
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

 Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-you may obtain a copy of the License at
+You may not use this file except in compliance with the License.
+You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANy KIND, either express or implied.
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
@@ -33,22 +33,22 @@ class MulKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     const Tensor* x = context.Input<Tensor>("X");
     const Tensor* y = context.Input<Tensor>("Y");
-    Tensor* Z = context.Output<Tensor>("Out");
+    Tensor* z = context.Output<Tensor>("Out");
     const Tensor x_matrix =
         x->dims().size() > 2
-            ? framework::FlattenToMatrix<T>(
-                  *x, context.template GetAttr<int>("x_num_row_dims"))
+            ? framework::ReshapeToMatrix<T>(
+                  *x, context.template GetAttr<int>("x_num_col_dims"))
             : *x;
     const Tensor y_matrix =
         y->dims().size() > 2
-            ? framework::FlattenToMatrix<T>(
-                  *y, context.template GetAttr<int>("y_num_row_dims"))
+            ? framework::ReshapeToMatrix<T>(
+                  *y, context.template GetAttr<int>("y_num_col_dims"))
             : *y;

-    Z->mutable_data<T>(context.GetPlace());
+    z->mutable_data<T>(context.GetPlace());
     auto* device_context =
         const_cast<platform::DeviceContext*>(context.device_context_);
-    math::matmul<Place, T>(x_matrix, false, y_matrix, false, 1, Z, 0,
+    math::matmul<Place, T>(x_matrix, false, y_matrix, false, 1, z, 0,
                            device_context);
   }
 };
@@ -57,15 +57,15 @@ template <typename Place, typename T>
 class MulGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    int x_num_row_dims = ctx.template GetAttr<int>("x_num_row_dims");
-    int y_num_row_dims = ctx.template GetAttr<int>("y_num_row_dims");
+    int x_num_col_dims = ctx.template GetAttr<int>("x_num_col_dims");
+    int y_num_col_dims = ctx.template GetAttr<int>("y_num_col_dims");
     const Tensor* x = ctx.Input<Tensor>("X");
     const Tensor* y = ctx.Input<Tensor>("Y");
     const Tensor x_matrix =
-        x->dims().size() > 2 ? framework::FlattenToMatrix<T>(*x, x_num_row_dims)
+        x->dims().size() > 2 ? framework::ReshapeToMatrix<T>(*x, x_num_col_dims)
                              : *x;
     const Tensor y_matrix =
-        y->dims().size() > 2 ? framework::FlattenToMatrix<T>(*y, y_num_row_dims)
+        y->dims().size() > 2 ? framework::ReshapeToMatrix<T>(*y, y_num_col_dims)
                              : *y;

     const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
@@ -75,8 +75,8 @@ class MulGradKernel : public framework::OpKernel {
         const_cast<platform::DeviceContext*>(ctx.device_context_);
     if (dx) {
       dx->mutable_data<T>(ctx.GetPlace());
-      Tensor dx_matrix = dx->dims().size() > 2 ? framework::FlattenToMatrix<T>(
-                                                     *dx, x_num_row_dims)
+      Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix<T>(
+                                                     *dx, x_num_col_dims)
                                                : *dx;
       // dx = dout * y'. dx: M x K, dout : M x N, y : K x N
       math::matmul<Place, T>(*dout, false, y_matrix, true, 1, &dx_matrix, 0,
@@ -84,8 +84,8 @@ class MulGradKernel : public framework::OpKernel {
     }
     if (dy) {
       dy->mutable_data<T>(ctx.GetPlace());
-      Tensor dy_matrix = dy->dims().size() > 2 ? framework::FlattenToMatrix<T>(
-                                                     *dy, y_num_row_dims)
+      Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix<T>(
+                                                     *dy, y_num_col_dims)
                                                : *dy;
       // dy = x' * dout. dy K x N, dout : M x N, x : M x K
       math::matmul<Place, T>(x_matrix, true, *dout, false, 1, &dy_matrix, 0,
...
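For reference, the transpose flags follow from the chain rule: with Out = X·Y, where x_matrix is M×K and y_matrix is K×N, dX = dOut·Yᵀ ((M×N)·(N×K) = M×K) and dY = Xᵀ·dOut ((K×M)·(M×N) = K×N), which is exactly what the two math::matmul calls compute.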
@@ -31,11 +31,11 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
         x_dims.size(), b_dims.size(),
         "The rank of input `X` must be larger than the one of input `b`.");
-    int num_row_dims = b_dims.size();
-    PADDLE_ENFORCE_EQ(framework::slice_ddim(
-                          x_dims, x_dims.size() - num_row_dims, x_dims.size()),
-                      b_dims, "The width of two operands must be same");
+    int num_col_dims = x_dims.size() - b_dims.size();
+    PADDLE_ENFORCE_EQ(
+        framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+        "The width of two operands must be same");
     PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1");
     ctx.Output<Tensor>("Out")->Resize(x_dims);
   }
@@ -72,10 +72,10 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
         x_dims.size(), b_dims.size(),
         "The rank of input `X` must be larger than the one of input `b`.");
-    int num_row_dims = b_dims.size();
-    PADDLE_ENFORCE_EQ(framework::slice_ddim(
-                          x_dims, x_dims.size() - num_row_dims, x_dims.size()),
-                      b_dims, "The width of two operands must be same");
+    int num_col_dims = x_dims.size() - b_dims.size();
+    PADDLE_ENFORCE_EQ(
+        framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+        "The width of two operands must be same");
     auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *db = ctx.Output<Tensor>(framework::GradVarName("b"));
     if (dx) dx->Resize(x_dims);
...
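A worked example of the new check, with hypothetical shapes (illustration only, not taken from the tests):

// X: (2, 3, 4, 5), b: (4, 5)
// num_col_dims = x_dims.size() - b_dims.size() = 4 - 2 = 2
// slice_ddim(x_dims, 2, 4) = (4, 5) == b_dims, so the width check passes.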
@@ -33,11 +33,12 @@ class RowwiseAddKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     auto out = context.Output<Tensor>("Out");
     out->mutable_data<T>(context.GetPlace());
-    int num_row_dims = context.Input<Tensor>("b")->dims().size();
+    int num_col_dims = context.Input<Tensor>("X")->dims().size() -
+                       context.Input<Tensor>("b")->dims().size();

     auto input =
-        EigenMatrix<T>::Reshape(*context.Input<Tensor>("X"), num_row_dims);
+        EigenMatrix<T>::Reshape(*context.Input<Tensor>("X"), num_col_dims);
     auto bias = EigenVector<T>::Flatten(*context.Input<Tensor>("b"));
-    auto output = EigenMatrix<T>::Reshape(*out, num_row_dims);
+    auto output = EigenMatrix<T>::Reshape(*out, num_col_dims);

     const int bias_size = bias.dimension(0);
     const int rest_size = input.size() / bias_size;
@@ -55,14 +56,15 @@ class RowwiseAddGradKernel : public framework::OpKernel {
     auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
     auto* db = context.Output<Tensor>(framework::GradVarName("b"));
-    int num_row_dims = context.Input<Tensor>("b")->dims().size();
+    int num_col_dims = context.Input<Tensor>("X")->dims().size() -
+                       context.Input<Tensor>("b")->dims().size();

-    auto out_grad = EigenMatrix<T>::Reshape(*dout, num_row_dims);
+    auto out_grad = EigenMatrix<T>::Reshape(*dout, num_col_dims);
     auto place = context.GetEigenDevice<Place>();

     if (dx) {
       dx->mutable_data<T>(context.GetPlace());
-      EigenMatrix<T>::Reshape(*dx, num_row_dims).device(place) = out_grad;
+      EigenMatrix<T>::Reshape(*dx, num_col_dims).device(place) = out_grad;
     }

     if (db) {
...
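With the hypothetical shapes from the note above, EigenMatrix<T>::Reshape(X, 2) views X as a (2*3) x (4*5) = 6 x 20 matrix, so bias_size is 20 and rest_size is 120 / 20 = 6: the flattened bias is added to each of the 6 rows.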
@@ -26,7 +26,7 @@ class TestMulOp2(unittest.TestCase):
             'X': np.random.random((15, 4, 12, 10)).astype("float32"),
             'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32")
         }
-        self.attrs = {'x_num_row_dims': 2, 'y_num_row_dims': 3}
+        self.attrs = {'x_num_col_dims': 2, 'y_num_col_dims': 2}
         self.outputs = {
             'Out': np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10),
                           self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9))
@@ -69,7 +69,7 @@ class TestMulGradOp(GradientChecker):
 class TestMulGradTest2(GradientChecker):
     def setUp(self):
         self.op = Operator(
-            "mul", X="X", Y="Y", Out="Out", x_num_row_dims=2, y_num_row_dims=3)
+            "mul", X="X", Y="Y", Out="Out", x_num_col_dims=2, y_num_col_dims=2)
         self.inputs = {
             "X": np.random.random((15, 4, 12, 10)).astype("float32"),
             "Y": np.random.random((4, 30, 8, 2, 9)).astype("float32")
...
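Note why the attribute for Y changes from 3 to 2 while the expected output stays the same: under the old row-based convention the last 3 dimensions of Y, (8, 2, 9), formed the matrix's second dimension; under the new column-based convention the first 2 dimensions, (4, 30), form its first. Both describe the same (4*30) x (8*2*9) = 120 x 144 reshape.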