提交 5aacd64b 编写于 作者: F fengjiayi

Follow comments

上级 0c13660a
......@@ -87,11 +87,11 @@ template <typename T, int MajorType = Eigen::RowMajor,
struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
// Flatten reshapes a Tensor into an EigenVector.
static typename EigenVector::Type Flatten(Tensor& tensor) {
return EigenVector::From(tensor, {static_cast<int>(product(tensor.dims_))});
return EigenVector::From(tensor, {product(tensor.dims_)});
}
static typename EigenVector::ConstType Flatten(const Tensor& tensor) {
return EigenVector::From(tensor, {static_cast<int>(product(tensor.dims_))});
return EigenVector::From(tensor, {product(tensor.dims_)});
}
};
......
......@@ -110,8 +110,7 @@ TEST(Eigen, Matrix) {
TEST(Eigen, MatrixReshape) {
Tensor t;
float* p =
t.mutable_data<float>(make_ddim({2, 3, 6, 4}), platform::CPUPlace());
float* p = t.mutable_data<float>({2, 3, 6, 4}, platform::CPUPlace());
for (int i = 0; i < 2 * 3 * 6 * 4; ++i) {
p[i] = static_cast<float>(i);
}
......
......@@ -267,7 +267,7 @@ TEST(Tensor, ReshapeToMatrix) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
int* src_ptr = src.mutable_data<int>(make_ddim({2, 3, 4, 9}), CPUPlace());
int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, CPUPlace());
for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
src_ptr[i] = i;
}
......
......@@ -27,8 +27,8 @@ class MulOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &ctx) const override {
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
int x_num_col_dims = GetAttr<int>("x_num_col_dims");
int y_num_col_dims = GetAttr<int>("y_num_col_dims");
int x_num_col_dims = Attr<int>("x_num_col_dims");
int y_num_col_dims = Attr<int>("y_num_col_dims");
PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
"The rank of input tensor X(%s) should be larger than "
......@@ -58,19 +58,19 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "The output of mul op");
AddAttr<int>(
"x_num_col_dims",
"mul_op can take tensors with more than two dimensions as input `X`, "
"in that case, tensors will be reshaped to a matrix. The matrix's "
"first dimension(column length) will be the product of tensor's last "
"`num_col_dims` dimensions, and the matrix's second dimension(row "
"length) will be the product of tensor's first `rank - num_col_dims` "
"dimensions.")
R"DOC(mul_op can take tensors with more than two dimensions as input `X`,
in that case, tensors will be reshaped to a matrix. The matrix's first
dimension(column length) will be the product of tensor's last
`num_col_dims` dimensions, and the matrix's second dimension(row length)
will be the product of tensor's first `rank - num_col_dims` dimensions.
)DOC")
.SetDefault(1)
.EqualLargerThan(1);
AddAttr<int>(
"y_num_col_dims",
"mul_op can take tensors with more than two dimensions as input `Y`, "
"in that case, tensors will be reshaped to a matrix. Just like input "
"`X`.")
R"DOC(mul_op can take tensors with more than two dimensions as input `Y`,
in that case, tensors will be reshaped to a matrix. Just like input `X`.
)DOC")
.SetDefault(1)
.EqualLargerThan(1);
AddComment(R"DOC(
......@@ -98,9 +98,9 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto x_mat_dims =
framework::flatten_to_2d(x_dims, GetAttr<int>("x_num_col_dims"));
framework::flatten_to_2d(x_dims, Attr<int>("x_num_col_dims"));
auto y_mat_dims =
framework::flatten_to_2d(y_dims, GetAttr<int>("y_num_col_dims"));
framework::flatten_to_2d(y_dims, Attr<int>("y_num_col_dims"));
PADDLE_ENFORCE_EQ(
x_mat_dims[0], out_dims[0],
......
......@@ -37,12 +37,12 @@ class MulKernel : public framework::OpKernel {
const Tensor x_matrix =
x->dims().size() > 2
? framework::ReshapeToMatrix<T>(
*x, context.template GetAttr<int>("x_num_col_dims"))
*x, context.template Attr<int>("x_num_col_dims"))
: *x;
const Tensor y_matrix =
y->dims().size() > 2
? framework::ReshapeToMatrix<T>(
*y, context.template GetAttr<int>("y_num_col_dims"))
*y, context.template Attr<int>("y_num_col_dims"))
: *y;
z->mutable_data<T>(context.GetPlace());
......@@ -57,8 +57,8 @@ template <typename Place, typename T>
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int x_num_col_dims = ctx.template GetAttr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template GetAttr<int>("y_num_col_dims");
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y");
const Tensor x_matrix =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册