Unverified commit 561f9013 authored by Jiabin Yang, committed by GitHub

[Prim] Support elementwise-related VJPs with primitives (#49784)

* support elementwise base func

* fix compiling error and add test

* remove additional param

* support vjp for div using comp

* remove additional change

* fix dy2st error with magic num

* fix dy magic num

* another magic

* another magic

* add more test

* fix windows problem

* another magic

* fix windows compile

* invoke ci

* add skip rename strategy

* support add vjp

* fix test_tanh

* support add with new axis cal

* fix resnet and some test

* add composite log

* support sub vjp
Parent dd827bbe
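
This change adds composite (primitive-based) VJP rules for elementwise add, subtract, and divide and registers them for both the static and eager paths. A minimal usage sketch, distilled from the unit tests added further down in this diff (core.set_prim_enabled and the other calls are the ones exercised by those tests; other Paddle versions may expose different switches):

# Sketch based on the new static-mode tests in this diff.
import numpy as np
import paddle
from paddle.fluid import core

core.set_prim_enabled(True)      # decompose elementwise grads into primitive ops
paddle.enable_static()

mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
    x = paddle.static.data('x', [2, 3, 3, 4], 'float32')
    y = paddle.static.data('y', [3, 1, 4], 'float32')
    x.stop_gradient = False
    y.stop_gradient = False
    out = paddle.divide(x, y)
    grads = paddle.static.gradients([out], [x, y])   # lowered via the divide_grad composite rule

exe = paddle.static.Executor()
exe.run(sp)
dx, dy = exe.run(
    program=mp,
    feed={'x': np.random.rand(2, 3, 3, 4).astype('float32'),
          'y': np.random.rand(3, 1, 4).astype('float32')},
    fetch_list=[g.name for g in grads],
)
core.set_prim_enabled(False)
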
......@@ -148,7 +148,7 @@ cc_library(ops_extra_info SRCS ops_extra_info.cc DEPS attribute cudnn_workspace_
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows_utils lapack_function
lod_tensor maxouting unpooling pooling lod_rank_table context_project
sequence_pooling executor generator static_prim_api)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc static_prim_api static_utils static_global_utils prim_utils)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc_functor matrix_inverse matrix_solve)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper ps_gpu_wrapper)
......@@ -216,7 +216,7 @@ endif()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
add_subdirectory(benchmark)
cc_test(op_debug_string_test SRCS op_debug_string_test.cc DEPS elementwise_add_op)
cc_test_old(op_debug_string_test SRCS op_debug_string_test.cc DEPS elementwise_add_op ${COMMON_OP_DEPS})
if (WITH_ASCEND_CL)
cc_test(transpose_op_npu_test SRCS transpose_op_npu_test.cc DEPS op_registry transpose_op scope device_context enforce executor)
endif()
......
......@@ -15,7 +15,9 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle {
namespace framework {
class OpDesc;
......@@ -49,6 +51,29 @@ class ElementwiseAddOpMaker : public ElementwiseOpMaker {
}
};
class ElementwiseAddGradCompositeOpMaker
: public prim::GradCompositeOpMakerBase {
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
public:
void Apply() override {
paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
paddle::experimental::Tensor y = this->GetSingleForwardInput("Y");
paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
auto dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
paddle::experimental::Tensor dy = this->GetSingleInputGrad("Y");
auto dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy);
int axis = static_cast<int>(this->Attr<int>("axis"));
VLOG(3) << "Runing add_grad composite func";
prim::add_grad<prim::DescTensor>(x, y, out_grad, axis, dx_ptr, dy_ptr);
this->RecoverOutputName(dx, dx_name);
this->RecoverOutputName(dy, dy_name);
}
};
template <typename T>
class ElementwiseAddDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -91,9 +116,17 @@ class ElementwiseAddTripleGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add);
REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_add, Add);
REGISTER_OPERATOR(elementwise_add,
::paddle::operators::ElementwiseOp,
::paddle::operators::ElementwiseAddOpMaker,
::paddle::operators::ElementwiseOpInferVarType,
elementwise_addGradMaker<::paddle::framework::OpDesc>,
elementwise_addGradMaker<::paddle::imperative::OpBase>,
::paddle::operators::ElementwiseAddGradCompositeOpMaker,
::paddle::operators::ElementwiseOpInplaceInferer);
namespace ops = paddle::operators;
REGISTER_OPERATOR(
elementwise_add_grad,
ops::ElementwiseOpGrad,
......
......@@ -19,7 +19,9 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle {
namespace operators {
......@@ -65,6 +67,31 @@ class ElementwiseDivGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
class ElementwiseDivGradCompositeOpMaker
: public prim::GradCompositeOpMakerBase {
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
public:
void Apply() override {
paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
paddle::experimental::Tensor y = this->GetSingleForwardInput("Y");
paddle::experimental::Tensor out = this->GetSingleForwardOutput("Out");
paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
auto dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
paddle::experimental::Tensor dy = this->GetSingleInputGrad("Y");
auto dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy);
int axis = static_cast<int>(this->Attr<int>("axis"));
VLOG(3) << "Runing div_grad composite func";
prim::divide_grad<prim::DescTensor>(
x, y, out, out_grad, axis, dx_ptr, dy_ptr);
this->RecoverOutputName(dx, dx_name);
this->RecoverOutputName(dy, dy_name);
}
};
template <typename T>
class ElementwiseDivDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -96,6 +123,7 @@ REGISTER_OPERATOR(elementwise_div,
ops::ElementwiseOp,
ops::ElementwiseDivOpMaker,
ops::ElementwiseOpInferVarType,
ops::ElementwiseDivGradCompositeOpMaker,
ops::ElementwiseDivGradOpMaker<paddle::framework::OpDesc>,
ops::ElementwiseDivGradOpMaker<paddle::imperative::OpBase>);
......
......@@ -15,7 +15,9 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle {
namespace framework {
class OpDesc;
......@@ -52,6 +54,29 @@ class ElementwiseSubOpMaker : public ElementwiseOpMaker {
}
};
class ElementwiseSubGradCompositeOpMaker
: public prim::GradCompositeOpMakerBase {
using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase;
public:
void Apply() override {
paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
paddle::experimental::Tensor y = this->GetSingleForwardInput("Y");
paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
auto dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
paddle::experimental::Tensor dy = this->GetSingleInputGrad("Y");
auto dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy);
int axis = static_cast<int>(this->Attr<int>("axis"));
VLOG(3) << "Runing sub_grad composite func";
prim::subtract_grad<prim::DescTensor>(x, y, out_grad, axis, dx_ptr, dy_ptr);
this->RecoverOutputName(dx, dx_name);
this->RecoverOutputName(dy, dy_name);
}
};
template <typename T>
class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -84,6 +109,7 @@ REGISTER_OPERATOR(elementwise_sub,
::paddle::operators::ElementwiseOpInferVarType,
elementwise_subGradMaker<::paddle::framework::OpDesc>,
elementwise_subGradMaker<::paddle::imperative::OpBase>,
::paddle::operators::ElementwiseSubGradCompositeOpMaker,
::paddle::operators::ElementwiseOpInplaceInferer);
REGISTER_OPERATOR(
......
......@@ -635,6 +635,7 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO
{%- endmacro %}
{% macro call_composite_backward_api(composite_op_dict) %}
VLOG(3) << "Runing {{composite_op_dict["composite"]["func_name"]}} composite func";
prim::{{composite_op_dict["composite"]["func_name"]}}<prim::DescTensor>({{composite_op_dict["composite"]["func_args"]}});
{%- endmacro %}
......
......@@ -27,5 +27,114 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
auto grad_x_tmp = multiply<T>(grad_out, tmp);
grad_x->set_impl(grad_x_tmp.impl());
}
template <typename T>
void subtract_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis,
Tensor* dx,
Tensor* dy) {
if (dy) {
auto scale_out_grad = scale<T>(out_grad, -1.0, 0.0, true);
if (phi::product(x.dims()) > phi::product(y.dims())) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims());
auto dy_reduce_res =
sum<T>(scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
} else {
by_pass<T>(scale_out_grad, dy);
}
}
if (dx) {
if (phi::product(y.dims()) > phi::product(x.dims())) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(y.dims(), x.dims());
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
} else {
by_pass<T>(out_grad, dx);
}
}
}
template <typename T>
void add_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis,
Tensor* dx,
Tensor* dy) {
if (dy) {
if (phi::product(x.dims()) > phi::product(y.dims())) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims());
auto dy_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
} else {
by_pass<T>(out_grad, dy);
}
}
if (dx) {
if (phi::product(y.dims()) > phi::product(x.dims())) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(y.dims(), x.dims());
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
} else {
by_pass<T>(out_grad, dx);
}
}
}
template <typename T>
void divide_grad(const Tensor& x,
const Tensor& y,
const Tensor& out,
const Tensor& out_grad,
int axis,
Tensor* dx,
Tensor* dy) {
if (dy) {
// dy = -(x/y^2) * dout
auto tmp0 = pow<T>(y, 2.0);
auto tmp1 = divide<T>(x, tmp0);
auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true);
auto dy_res = multiply<T>(tmp2, out_grad);
if (phi::product(x.dims()) > phi::product(y.dims())) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims());
auto dy_reduce_res =
sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
dy->set_impl(dy_tmp.impl());
} else {
dy->set_impl(dy_res.impl());
}
} // indicate we will compute dy
if (dx) {
// dx = (1/y) * dout
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0);
auto tmp0 = divide<T>(one_tensor, y);
auto dx_res = multiply<T>(tmp0, out_grad);
if (phi::product(y.dims()) > phi::product(x.dims())) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(y.dims(), x.dims());
auto dx_reduce_res =
sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
dx->set_impl(dx_tmp.impl());
} else {
dx->set_impl(dx_res.impl());
}
} // indicate we will compute dx
}
} // namespace prim
} // namespace paddle
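
For reference, the shape handling in subtract_grad, add_grad, and divide_grad above follows the usual broadcast VJP rule: compute each gradient at the broadcast shape, sum it over the dimensions returned by get_reduce_dims, and reshape it back to the input's shape. An illustrative NumPy sketch of that rule, not Paddle code, assuming trailing-aligned broadcasting:

# Illustration only: sum over broadcast dims, then reshape back to the input shape.
import numpy as np

def reduce_to_shape(grad, shape):
    extra = grad.ndim - len(shape)
    axes = tuple(range(extra)) + tuple(
        i + extra for i, s in enumerate(shape)
        if s == 1 and grad.shape[i + extra] != 1
    )
    return grad.sum(axis=axes).reshape(shape)

x = np.random.rand(2, 3, 3, 4)
y = np.random.rand(3, 1, 4)
dout = np.ones_like(x + y)

# add:      dx = dout,       dy = dout        (each reduced to its input's shape)
# subtract: dx = dout,       dy = -dout
# divide:   dx = dout / y,   dy = -dout * x / y**2
dx_div = reduce_to_shape(dout / y, x.shape)
dy_div = reduce_to_shape(-dout * x / y**2, y.shape)
assert dx_div.shape == x.shape and dy_div.shape == y.shape
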
......@@ -15,7 +15,7 @@
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/phi/capi/include/wrapper_base.h"
namespace paddle {
namespace prim {
template <>
......@@ -35,5 +35,27 @@ template <>
Tensor multiply<Tensor>(const Tensor& x, const Tensor& y) {
return ::multiply_ad_func(x, y);
}
template <>
Tensor divide<Tensor>(const Tensor& x, const Tensor& y) {
return ::divide_ad_func(x, y);
}
template <>
Tensor full<Tensor>(paddle::experimental::IntArray shape,
paddle::experimental::Scalar value,
paddle::experimental::DataType dtype,
paddle::platform::Place place) {
return ::full_ad_func(shape, value, dtype, place);
}
template <>
Tensor sum<Tensor>(Tensor x, IntArray axis, DataType dtype, bool keepdim) {
return ::sum_ad_func(x, axis, dtype, keepdim);
}
template <>
Tensor reshape<Tensor>(Tensor x, IntArray shape) {
return ::reshape_ad_func(x, shape);
}
} // namespace prim
} // namespace paddle
......@@ -13,12 +13,14 @@
// limitations under the License.
#pragma once
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/optional.h"
namespace paddle {
namespace prim {
using Tensor = paddle::experimental::Tensor;
using IntArray = paddle::experimental::IntArray;
using Scalar = paddle::experimental::Scalar;
template <typename T>
Tensor pow(const Tensor& x, const paddle::experimental::Scalar& y);
......@@ -31,5 +33,22 @@ Tensor scale(const Tensor& X,
template <typename T>
Tensor multiply(const Tensor& x, const Tensor& y);
template <typename T>
Tensor divide(const Tensor& x, const Tensor& y);
template <typename T>
Tensor full(IntArray shape,
Scalar value,
DataType dtype = DataType::FLOAT32,
Place place = CPUPlace());
template <typename T>
Tensor sum(Tensor x,
IntArray axis = {},
DataType dtype = DataType::UNDEFINED,
bool keepdim = false);
template <typename T>
Tensor reshape(Tensor x, IntArray shape);
} // namespace prim
} // namespace paddle
......@@ -30,6 +30,9 @@
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace prim {
......@@ -91,5 +94,100 @@ Tensor multiply<DescTensor>(const Tensor& x, const Tensor& y) {
return out;
}
template <>
Tensor divide<DescTensor>(const Tensor& x, const Tensor& y) {
// Grad infershape
Tensor out = empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("elementwise_div");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetInput("Y",
{std::static_pointer_cast<prim::DescTensor>(y.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
template <>
Tensor full<DescTensor>(paddle::experimental::IntArray shape,
paddle::experimental::Scalar value,
paddle::experimental::DataType dtype,
paddle::platform::Place place) {
// Grad infershape
Tensor out = empty<DescTensor>({}, dtype, place);
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("fill_constant");
op->SetAttr("shape", shape.GetData());
PADDLE_ENFORCE_EQ(
((dtype == paddle::experimental::DataType::FLOAT32) ||
(dtype == paddle::experimental::DataType::FLOAT16)),
true,
phi::errors::InvalidArgument(
"We only support float32/float16 for full, but we got data type: %s",
phi::DataTypeToString(dtype)));
op->SetAttr("value", value.to<float>());
op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
return out;
}
template <>
Tensor sum<DescTensor>(Tensor x,
paddle::experimental::IntArray axis,
paddle::experimental::DataType dtype,
bool keepdim) {
// Grad infershape
Tensor out = empty<DescTensor>({}, dtype, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("reduce_sum");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
std::vector<int> res;
for (auto value : axis.GetData()) {
res.push_back(static_cast<int>(value));
}
op->SetAttr("dim", res);
op->SetAttr("keep_dim", keepdim);
op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs();
op->InferVarType(block);
// TODO(jiabin): The output shape may only be known at runtime, so skip infershape for now.
return out;
}
template <>
Tensor reshape<DescTensor>(Tensor x, paddle::experimental::IntArray shape) {
// Grad infershape
Tensor out = empty<DescTensor>({}, x.dtype(), paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("reshape");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
std::vector<int> res;
for (auto value : shape.GetData()) {
// TODO(jiabin): This cast is not safe for now, find a way to handle this.
res.push_back(static_cast<int>(value));
}
op->SetAttr("shape", res);
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
op->CheckAttrs();
op->InferVarType(block);
// TODO(jiabin): The output shape may only be known at runtime, so skip infershape for now.
return out;
}
} // namespace prim
} // namespace paddle
......@@ -38,6 +38,9 @@ Tensor empty_like<Tensor>(const paddle::experimental::Tensor& x,
}
return empty_like_ad_func(x, dtype, place);
}
template <>
void by_pass<Tensor>(const paddle::experimental::Tensor& x, Tensor* out) {
out->set_impl(x.impl());
}
} // namespace prim
} // namespace paddle
......@@ -47,5 +47,23 @@ Tensor empty_like<DescTensor>(const Tensor& x,
paddle::experimental::IntArray(x.shape()), x.dtype(), paddle::Place());
}
template <>
void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
paddle::experimental::Tensor* out) {
Tensor new_out =
empty<DescTensor>({}, phi::DataType::FLOAT32, paddle::Place());
framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
framework::OpDesc* op = block->AppendOp();
op->SetType("assign");
op->SetInput("X",
{std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out->impl())->Name()});
op->CheckAttrs();
op->InferVarType(block);
op->InferShape(*block);
out->set_impl(new_out.impl());
}
} // namespace prim
} // namespace paddle
......@@ -19,6 +19,8 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/ddim.h"
using IntArray = paddle::experimental::IntArray;
namespace paddle {
namespace prim {
// We put some api like utils here
......@@ -31,6 +33,46 @@ template <typename T>
paddle::experimental::Tensor empty_like(const paddle::experimental::Tensor& x,
paddle::experimental::DataType dtype,
const paddle::Place& place);
template <typename T>
void by_pass(const paddle::experimental::Tensor& x,
paddle::experimental::Tensor* out);
// These methods don't need template specialization.
static phi::DDim get_reduce_dims(const phi::DDim& x_dims,
const phi::DDim& y_dims) {
std::vector<int64_t> result;
PADDLE_ENFORCE_GE(phi::product(x_dims),
phi::product(y_dims),
phi::errors::InvalidArgument(
"Only x_dims >= y_dims is accepted for "
"get_reduce_dims, but we got x_dims: %s, y_dims: %s",
x_dims,
y_dims));
int bat = x_dims.size() - y_dims.size();
for (int i = 0; i < bat; ++i) {
result.push_back(i);
}
for (int i = 0; i < y_dims.size(); ++i) {
if (y_dims[i] == 1) {
result.push_back(i + bat);
} else {
PADDLE_ENFORCE_EQ(
y_dims[i],
x_dims[i + bat],
platform::errors::InvalidArgument(
"ReduceDims dimension mismatch. Operands could "
"not be broadcast together with the shape of x_dims = [%s] and "
"the shape of y_dims = [%s]. Received [%d] in X is not equal to "
"[%d] in Y at i:%d.",
x_dims,
y_dims,
x_dims[i + bat],
y_dims[i],
i));
}
}
auto res_dims = phi::make_ddim(result);
VLOG(4) << "Reduce Dims is: " << res_dims;
return res_dims;
}
} // namespace prim
} // namespace paddle
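
get_reduce_dims assumes x_dims broadcasts over y_dims with trailing alignment and returns the axes of x that have to be summed away. A rough Python equivalent of the same logic, for illustration only (the C++ version checks element counts rather than ranks):

# Illustration only: which axes of x must be reduced so the gradient matches y's shape.
def get_reduce_dims(x_dims, y_dims):
    assert len(x_dims) >= len(y_dims)      # simplified precondition
    bat = len(x_dims) - len(y_dims)
    result = list(range(bat))              # leading dims of x that y lacks
    for i, d in enumerate(y_dims):
        if d == 1:
            result.append(i + bat)         # broadcast dim of y -> reduce in x
        else:
            assert d == x_dims[i + bat], "shapes are not broadcast-compatible"
    return result

print(get_reduce_dims([2, 3, 3, 4], [3, 1, 4]))   # [0, 2]
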
......@@ -19,7 +19,7 @@ set(prim_generated_deps final_dygraph_function final_dygraph_node
dygraph_function dygraph_node)
cc_test_old(
test_static_prim
test_comp_static
SRCS
test_static_prim.cc
DEPS
......@@ -37,7 +37,7 @@ cc_test_old(
if(NOT (NOT WITH_PYTHON AND ON_INFER))
cc_test_old(
test_eager_prim
test_comp_eager
SRCS
test_eager_prim.cc
DEPS
......
......@@ -45,6 +45,8 @@ class DescTensor : public phi::ExtendedTensor,
framework::VarDesc* get_ptr() { return desc_ptr_; }
const phi::Place& place() const override { return place_; }
// TODO(jiabin): override more operators here.
private:
......@@ -55,6 +57,7 @@ class DescTensor : public phi::ExtendedTensor,
// we can inherit from ExtendedTensor. Remove this when we make VarDesc the
// same as Tensor, or make Tensor's dims more lightweight.
mutable phi::DDim dims_;
phi::Place place_;
};
} // namespace prim
......
......@@ -42,6 +42,7 @@
kernel :
func : add_grad
no_need_buffer : x, y
composite : add_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
backward : add_double_grad
inplace : (out_grad -> x_grad)
......@@ -375,6 +376,7 @@
param : [x, y]
kernel :
func : divide_grad
composite : divide_grad(Tensor x, Tensor y, Tensor out, Tensor out_grad, int axis = -1)
backward : divide_double_grad
- backward_op : dropout_grad
......@@ -1325,6 +1327,7 @@
kernel :
func : subtract_grad
no_need_buffer : x, y
composite : subtract_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
backward : subtract_double_grad
inplace : (out_grad -> x_grad)
......
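
The composite entries above bind each backward op to the signature of its composite function, so the generated grad-op maker calls, e.g., prim::add_grad<prim::DescTensor>(...) instead of scheduling the handwritten grad kernel. With prim enabled, the grad block is therefore built from primitive ops. A hedged way to inspect this for the static program mp from the sketch at the top of this page (op names are the ones appended by the static prim API in this diff):

# Sketch only: inspect which ops the composite grad produced.
grad_ops = [op.type for op in mp.global_block().ops]
print(grad_ops)
# With core.set_prim_enabled(True) this is expected to list primitives appended by
# the static prim API, e.g. 'fill_constant', 'elementwise_div', 'reduce_sum',
# 'reshape', for the divide example; with prim disabled, a single fused grad op
# such as 'elementwise_div_grad' appears instead.
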
......@@ -14,10 +14,10 @@ limitations under the License. */
#pragma once
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_meta.h"
namespace phi {
/// \brief The ExtendedTensor is an interface for custom-designed classes.
......
......@@ -24,6 +24,7 @@ from predictor_utils import PredictorTools
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
place = (
......@@ -234,6 +235,18 @@ class TestBert(unittest.TestCase):
self.verify_predict()
def test_train_composite(self):
core.set_prim_enabled(True)
static_loss, static_ppl = self.train_static(
self.bert_config, self.data_reader
)
core.set_prim_enabled(False)
dygraph_loss, dygraph_ppl = self.train_dygraph(
self.bert_config, self.data_reader
)
np.testing.assert_allclose(static_loss, dygraph_loss, rtol=1e-05)
np.testing.assert_allclose(static_ppl, dygraph_ppl, rtol=1e-05)
def verify_predict(self):
for data in self.data_reader.data_generator()():
dygraph_pred_res = self.predict_dygraph(self.bert_config, data)
......
......@@ -23,6 +23,7 @@ from predictor_utils import PredictorTools
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn import BatchNorm
......@@ -425,6 +426,21 @@ class TestResnet(unittest.TestCase):
)
self.verify_predict()
def test_resnet_composite(self):
core.set_prim_enabled(True)
static_loss = self.train(to_static=True)
core.set_prim_enabled(False)
dygraph_loss = self.train(to_static=True)
np.testing.assert_allclose(
static_loss,
dygraph_loss,
rtol=1e-05,
err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss
),
)
core.set_prim_enabled(False)
def test_in_static_mode_mkldnn(self):
fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
......
......@@ -20,6 +20,7 @@ from test_resnet import SEED, ResNet, optimizer_setting
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
# NOTE: Reduce batch_size from 8 to 2 to avoid unittest timeout.
batch_size = 2
......@@ -128,6 +129,20 @@ class TestResnet(unittest.TestCase):
),
)
def test_resnet_composite(self):
core.set_prim_enabled(True)
static_loss = self.train(to_static=True)
core.set_prim_enabled(False)
dygraph_loss = self.train(to_static=False)
np.testing.assert_allclose(
static_loss,
dygraph_loss,
rtol=1e-05,
err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss
),
)
if __name__ == '__main__':
unittest.main()
......@@ -20,6 +20,7 @@ from test_resnet import SEED, ResNet, optimizer_setting
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
# NOTE: Reduce batch_size from 8 to 2 to avoid unittest timeout.
batch_size = 2
......@@ -134,6 +135,23 @@ class TestResnet(unittest.TestCase):
),
)
def test_resnet_composite(self):
if fluid.is_compiled_with_cuda():
core.set_prim_enabled(True)
static_loss = self.train(to_static=True)
core.set_prim_enabled(False)
dygraph_loss = self.train(to_static=False)
# NOTE: In pure fp16 training, loss is not stable, so we enlarge atol here.
np.testing.assert_allclose(
static_loss,
dygraph_loss,
rtol=1e-05,
atol=0.001,
err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss
),
)
if __name__ == '__main__':
unittest.main()
......@@ -22,6 +22,7 @@ import numpy as np
from predictor_utils import PredictorTools
import paddle
from paddle.fluid import core
SEED = 2020
IMAGENET1000 = 1281167
......@@ -424,6 +425,20 @@ class TestResnet(unittest.TestCase):
)
self.verify_predict()
def test_resnet_composite(self):
core.set_prim_enabled(True)
static_loss = self.train(to_static=True)
core.set_prim_enabled(False)
dygraph_loss = self.train(to_static=False)
np.testing.assert_allclose(
static_loss,
dygraph_loss,
rtol=1e-05,
err_msg='static_loss: {} \n dygraph_loss: {}'.format(
static_loss, dygraph_loss
),
)
def test_in_static_mode_mkldnn(self):
paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
......
......@@ -9,4 +9,7 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
set_tests_properties(test_eager_tanh_grad_comp PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_eager_tanh_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_eager_div_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_eager_add_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_eager_sub_grad PROPERTIES TIMEOUT 60)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
core.set_prim_enabled(True)
@param.parameterized_class(
('primal0', 'primal1', 'dtype'),
[
(
np.random.rand(2, 3, 4),
np.random.rand(2, 3, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(3, 1, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(2, 3, 1, 4),
np.float32,
),
],
)
class TestAddGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_add_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
x.stop_gradient = False
y.stop_gradient = False
out = paddle.add(x, y)
res = paddle.grad(out, [x, y], create_graph=True, retain_graph=True)
return res[0].numpy(), res[1].numpy()
def desired(primal0, primal1):
core.set_prim_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
x.stop_gradient = False
y.stop_gradient = False
out = paddle.add(x, y)
res = paddle.grad(out, [x, y], create_graph=True, retain_graph=True)
return res[0].numpy(), res[1].numpy()
dx, dy = actual(self.primal0, self.primal1)
ddx, ddy = desired(self.primal0, self.primal1)
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-6,
atol=0,
)
np.testing.assert_allclose(
actual=dy,
desired=ddy,
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
core.set_prim_enabled(True)
@param.parameterized_class(
('primal0', 'primal1', 'dtype'),
[
(
np.random.rand(2, 3, 4),
np.random.rand(2, 3, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(3, 1, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(2, 3, 1, 4),
np.float32,
),
],
)
class TestDivGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_div_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
x.stop_gradient = False
y.stop_gradient = False
out = paddle.divide(x, y)
res = paddle.grad(out, [x, y], create_graph=True, retain_graph=True)
return res[0].numpy(), res[1].numpy()
def desired(primal0, primal1):
core.set_prim_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
x.stop_gradient = False
y.stop_gradient = False
out = paddle.divide(x, y)
res = paddle.grad(out, [x, y], create_graph=True, retain_graph=True)
return res[0].numpy(), res[1].numpy()
dx, dy = actual(self.primal0, self.primal1)
ddx, ddy = desired(self.primal0, self.primal1)
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-6,
atol=0,
)
np.testing.assert_allclose(
actual=dy,
desired=ddy,
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
core.set_prim_enabled(True)
@param.parameterized_class(
('primal0', 'primal1', 'dtype'),
[
(
np.random.rand(2, 3, 4),
np.random.rand(2, 3, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(3, 1, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(2, 3, 1, 4),
np.float32,
),
],
)
class TestSubGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_sub_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
x.stop_gradient = False
y.stop_gradient = False
out = paddle.subtract(x, y)
res = paddle.grad(out, [x, y], create_graph=True, retain_graph=True)
return res[0].numpy(), res[1].numpy()
def desired(primal0, primal1):
core.set_prim_enabled(False)
paddle.disable_static()
x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False)
y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False)
x.stop_gradient = False
y.stop_gradient = False
out = paddle.subtract(x, y)
res = paddle.grad(out, [x, y], create_graph=True, retain_graph=True)
return res[0].numpy(), res[1].numpy()
dx, dy = actual(self.primal0, self.primal1)
ddx, ddy = desired(self.primal0, self.primal1)
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-6,
atol=0,
)
np.testing.assert_allclose(
actual=dy,
desired=ddy,
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
if __name__ == '__main__':
unittest.main()
......@@ -9,4 +9,8 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach()
set_tests_properties(test_tanh_grad_comp PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_tanh_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_div_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_add_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_sub_grad PROPERTIES TIMEOUT 60)
set_tests_properties(test_comp_add_tanh_grad PROPERTIES TIMEOUT 60)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
@param.parameterized_class(
('primal0', 'primal1', 'dtype'),
[
(
np.random.rand(2, 3, 4),
np.random.rand(2, 3, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(3, 1, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(2, 3, 1, 4),
np.float32,
),
],
)
class TestAddTanhGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_add_tanh_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
y = paddle.static.data('primal1', primal1.shape, primal1.dtype)
x.stop_gradient = False
y.stop_gradient = False
z = paddle.add(x, y)
out = paddle.tanh(z)
res = paddle.static.gradients([out], [x, y])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': primal0,
'primal1': primal1,
},
fetch_list=[res[0].name, res[1].name],
)
return out[0], out[1]
def desired(primal0, primal1):
core.set_prim_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data(
'primal0', self.primal0.shape, self.primal0.dtype
)
y = paddle.static.data(
'primal1', self.primal1.shape, self.primal1.dtype
)
x.stop_gradient = False
y.stop_gradient = False
z = paddle.add(x, y)
out = paddle.tanh(z)
res = paddle.static.gradients([out], [x, y])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': self.primal0,
'primal1': self.primal1,
},
fetch_list=[res[0].name, res[1].name],
)
return out[0], out[1]
dx, dy = actual(self.primal0, self.primal1)
ddx, ddy = desired(self.primal0, self.primal1)
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-6,
atol=0,
)
np.testing.assert_allclose(
actual=dy,
desired=ddy,
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
@param.parameterized_class(
('primal0', 'primal1', 'dtype'),
[
(
np.random.rand(2, 3, 4),
np.random.rand(2, 3, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(3, 1, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(2, 3, 1, 4),
np.float32,
),
],
)
class TestAddGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_add_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
y = paddle.static.data('primal1', primal1.shape, primal1.dtype)
x.stop_gradient = False
y.stop_gradient = False
z = paddle.add(x, y)
res = paddle.static.gradients([z], [x, y])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': primal0,
'primal1': primal1,
},
fetch_list=[res[0].name, res[1].name],
)
return out[0], out[1]
def desired(primal0, primal1):
core.set_prim_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data(
'primal0', self.primal0.shape, self.primal0.dtype
)
y = paddle.static.data(
'primal1', self.primal1.shape, self.primal1.dtype
)
x.stop_gradient = False
y.stop_gradient = False
z = paddle.add(x, y)
res = paddle.static.gradients([z], [x, y])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': self.primal0,
'primal1': self.primal1,
},
fetch_list=[res[0].name, res[1].name],
)
return out[0], out[1]
dx, dy = actual(self.primal0, self.primal1)
ddx, ddy = desired(self.primal0, self.primal1)
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-6,
atol=0,
)
np.testing.assert_allclose(
actual=dy,
desired=ddy,
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
@param.parameterized_class(
('primal0', 'primal1', 'dtype'),
[
(
np.random.rand(2, 3, 4),
np.random.rand(2, 3, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(3, 1, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(2, 3, 1, 4),
np.float32,
),
],
)
class TestDivGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_div_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
y = paddle.static.data('primal1', primal1.shape, primal1.dtype)
x.stop_gradient = False
y.stop_gradient = False
z = paddle.divide(x, y)
res = paddle.static.gradients([z], [x, y])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': primal0,
'primal1': primal1,
},
fetch_list=[res[0].name, res[1].name],
)
return out[0], out[1]
def desired(primal0, primal1):
core.set_prim_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data(
'primal0', self.primal0.shape, self.primal0.dtype
)
y = paddle.static.data(
'primal1', self.primal1.shape, self.primal1.dtype
)
x.stop_gradient = False
y.stop_gradient = False
z = paddle.divide(x, y)
res = paddle.static.gradients([z], [x, y])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': self.primal0,
'primal1': self.primal1,
},
fetch_list=[res[0].name, res[1].name],
)
return out[0], out[1]
dx, dy = actual(self.primal0, self.primal1)
ddx, ddy = desired(self.primal0, self.primal1)
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-6,
atol=0,
)
np.testing.assert_allclose(
actual=dy,
desired=ddy,
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import parameterized as param
import paddle
from paddle.fluid import core
@param.parameterized_class(
('primal0', 'primal1', 'dtype'),
[
(
np.random.rand(2, 3, 4),
np.random.rand(2, 3, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(3, 1, 4),
np.float32,
),
(
np.random.rand(2, 3, 3, 4),
np.random.rand(2, 3, 1, 4),
np.float32,
),
],
)
class TestSubGradComp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.primal0 = cls.primal0.astype(cls.dtype)
cls.primal1 = cls.primal1.astype(cls.dtype)
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_sub_grad_comp(self):
def actual(primal0, primal1):
core.set_prim_enabled(True)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data('primal0', primal0.shape, primal0.dtype)
y = paddle.static.data('primal1', primal1.shape, primal1.dtype)
x.stop_gradient = False
y.stop_gradient = False
out = paddle.subtract(x, y)
res = paddle.static.gradients([out], [x, y])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': primal0,
'primal1': primal1,
},
fetch_list=[res[0].name, res[1].name],
)
return out[0], out[1]
def desired(primal0, primal1):
core.set_prim_enabled(False)
mp, sp = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.static.data(
'primal0', self.primal0.shape, self.primal0.dtype
)
y = paddle.static.data(
'primal1', self.primal1.shape, self.primal1.dtype
)
x.stop_gradient = False
y.stop_gradient = False
out = paddle.subtract(x, y)
res = paddle.static.gradients([out], [x, y])
exe = paddle.static.Executor()
exe.run(sp)
out = exe.run(
program=mp,
feed={
'primal0': self.primal0,
'primal1': self.primal1,
},
fetch_list=[res[0].name, res[1].name],
)
return out[0], out[1]
dx, dy = actual(self.primal0, self.primal1)
ddx, ddy = desired(self.primal0, self.primal1)
np.testing.assert_allclose(
actual=dx,
desired=ddx,
rtol=1e-6,
atol=0,
)
np.testing.assert_allclose(
actual=dy,
desired=ddy,
rtol=1e-6,
atol=0,
)
core.set_prim_enabled(False)
if __name__ == '__main__':
unittest.main()
......@@ -60,7 +60,7 @@ class TestTanhGradComp(unittest.TestCase):
return exe.run(
program=mp,
feed={'primal': primal, 'cotangent': cotangent},
fetch_list=mp.blocks[0].ops[-1].output('Out')[0],
fetch_list=[x_cotangent[0].name],
)[0]
def desired(primal, cotangent):
......