1*ec9796c4SHansol Suh #pragma once 2*ec9796c4SHansol Suh 3*ec9796c4SHansol Suh #include <petsctao.h> 4*ec9796c4SHansol Suh #include <petscsf.h> 5*ec9796c4SHansol Suh #include <petscdevice.h> 6*ec9796c4SHansol Suh #include <petscdevice_cupm.h> 7*ec9796c4SHansol Suh 8*ec9796c4SHansol Suh /* 9*ec9796c4SHansol Suh User-defined application context - contains data needed by the 10*ec9796c4SHansol Suh application-provided call-back routines that evaluate the function, 11*ec9796c4SHansol Suh gradient, and hessian. 12*ec9796c4SHansol Suh */ 13*ec9796c4SHansol Suh 14*ec9796c4SHansol Suh typedef struct _Rosenbrock { 15*ec9796c4SHansol Suh PetscInt bs; // each block of bs variables is one chained multidimensional rosenbrock problem 16*ec9796c4SHansol Suh PetscInt i_start, i_end; 17*ec9796c4SHansol Suh PetscInt c_start, c_end; 18*ec9796c4SHansol Suh PetscReal alpha; // condition parameter 19*ec9796c4SHansol Suh } Rosenbrock; 20*ec9796c4SHansol Suh 21*ec9796c4SHansol Suh typedef struct _AppCtx *AppCtx; 22*ec9796c4SHansol Suh struct _AppCtx { 23*ec9796c4SHansol Suh MPI_Comm comm; 24*ec9796c4SHansol Suh PetscInt n; /* dimension */ 25*ec9796c4SHansol Suh PetscInt n_local; 26*ec9796c4SHansol Suh PetscInt n_local_comp; 27*ec9796c4SHansol Suh Rosenbrock problem; 28*ec9796c4SHansol Suh Vec Hvalues; /* vector for writing COO values of this MPI process */ 29*ec9796c4SHansol Suh Vec gvalues; /* vector for writing gradient values of this mpi process */ 30*ec9796c4SHansol Suh Vec fvector; 31*ec9796c4SHansol Suh PetscSF off_process_scatter; 32*ec9796c4SHansol Suh PetscSF gscatter; 33*ec9796c4SHansol Suh Vec off_process_values; /* buffer for off-process values if chained */ 34*ec9796c4SHansol Suh PetscBool test_lmvm; 35*ec9796c4SHansol Suh PetscLogEvent event_f, event_g, event_fg; 36*ec9796c4SHansol Suh }; 37*ec9796c4SHansol Suh 38*ec9796c4SHansol Suh /* -------------- User-defined routines ---------- */ 39*ec9796c4SHansol Suh 40*ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjective(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2) 41*ec9796c4SHansol Suh { 42*ec9796c4SHansol Suh PetscScalar d = x_2 - x_1 * x_1; 43*ec9796c4SHansol Suh PetscScalar e = 1.0 - x_1; 44*ec9796c4SHansol Suh return alpha * d * d + e * e; 45*ec9796c4SHansol Suh } 46*ec9796c4SHansol Suh 47*ec9796c4SHansol Suh static const PetscLogDouble RosenbrockObjectiveFlops = 7.0; 48*ec9796c4SHansol Suh 49*ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2]) 50*ec9796c4SHansol Suh { 51*ec9796c4SHansol Suh PetscScalar d = x_2 - x_1 * x_1; 52*ec9796c4SHansol Suh PetscScalar e = 1.0 - x_1; 53*ec9796c4SHansol Suh PetscScalar g2 = alpha * d * 2.0; 54*ec9796c4SHansol Suh 55*ec9796c4SHansol Suh g[0] = -2.0 * x_1 * g2 - 2.0 * e; 56*ec9796c4SHansol Suh g[1] = g2; 57*ec9796c4SHansol Suh } 58*ec9796c4SHansol Suh 59*ec9796c4SHansol Suh static const PetscInt RosenbrockGradientFlops = 9.0; 60*ec9796c4SHansol Suh 61*ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL PetscReal RosenbrockObjectiveGradient(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar g[2]) 62*ec9796c4SHansol Suh { 63*ec9796c4SHansol Suh PetscScalar d = x_2 - x_1 * x_1; 64*ec9796c4SHansol Suh PetscScalar e = 1.0 - x_1; 65*ec9796c4SHansol Suh PetscScalar ad = alpha * d; 66*ec9796c4SHansol Suh PetscScalar g2 = ad * 2.0; 67*ec9796c4SHansol Suh 68*ec9796c4SHansol Suh g[0] = -2.0 * x_1 * g2 - 2.0 * e; 69*ec9796c4SHansol Suh g[1] = g2; 70*ec9796c4SHansol Suh return ad * d + e * e; 71*ec9796c4SHansol Suh } 72*ec9796c4SHansol Suh 73*ec9796c4SHansol Suh static const PetscLogDouble RosenbrockObjectiveGradientFlops = 12.0; 74*ec9796c4SHansol Suh 75*ec9796c4SHansol Suh static PETSC_HOSTDEVICE_INLINE_DECL void RosenbrockHessian(PetscScalar alpha, PetscScalar x_1, PetscScalar x_2, PetscScalar h[4]) 76*ec9796c4SHansol Suh { 77*ec9796c4SHansol Suh PetscScalar d = x_2 - x_1 * x_1; 78*ec9796c4SHansol Suh PetscScalar g2 = alpha * d * 2.0; 79*ec9796c4SHansol Suh PetscScalar h2 = -4.0 * alpha * x_1; 80*ec9796c4SHansol Suh 81*ec9796c4SHansol Suh h[0] = -2.0 * (g2 + x_1 * h2) + 2.0; 82*ec9796c4SHansol Suh h[1] = h[2] = h2; 83*ec9796c4SHansol Suh h[3] = 2.0 * alpha; 84*ec9796c4SHansol Suh } 85*ec9796c4SHansol Suh 86*ec9796c4SHansol Suh static const PetscLogDouble RosenbrockHessianFlops = 11.0; 87*ec9796c4SHansol Suh 88*ec9796c4SHansol Suh static PetscErrorCode AppCtxCreate(MPI_Comm comm, AppCtx *ctx) 89*ec9796c4SHansol Suh { 90*ec9796c4SHansol Suh AppCtx user; 91*ec9796c4SHansol Suh PetscDeviceContext dctx; 92*ec9796c4SHansol Suh 93*ec9796c4SHansol Suh PetscFunctionBegin; 94*ec9796c4SHansol Suh PetscCall(PetscNew(ctx)); 95*ec9796c4SHansol Suh user = *ctx; 96*ec9796c4SHansol Suh user->comm = PETSC_COMM_WORLD; 97*ec9796c4SHansol Suh 98*ec9796c4SHansol Suh /* Initialize problem parameters */ 99*ec9796c4SHansol Suh user->n = 2; 100*ec9796c4SHansol Suh user->problem.alpha = 99.0; 101*ec9796c4SHansol Suh user->problem.bs = 2; // bs = 2 is block Rosenbrock, bs = n is chained Rosenbrock 102*ec9796c4SHansol Suh user->test_lmvm = PETSC_FALSE; 103*ec9796c4SHansol Suh /* Check for command line arguments to override defaults */ 104*ec9796c4SHansol Suh PetscOptionsBegin(user->comm, NULL, "Rosenbrock example", NULL); 105*ec9796c4SHansol Suh PetscCall(PetscOptionsInt("-n", "Rosenbrock problem size", NULL, user->n, &user->n, NULL)); 106*ec9796c4SHansol Suh PetscCall(PetscOptionsInt("-bs", "Rosenbrock block size (2 <= bs <= n)", NULL, user->problem.bs, &user->problem.bs, NULL)); 107*ec9796c4SHansol Suh PetscCall(PetscOptionsReal("-alpha", "Rosenbrock off-diagonal coefficient", NULL, user->problem.alpha, &user->problem.alpha, NULL)); 108*ec9796c4SHansol Suh PetscCall(PetscOptionsBool("-test_lmvm", "Test LMVM solve againt LMVM mult", NULL, user->test_lmvm, &user->test_lmvm, NULL)); 109*ec9796c4SHansol Suh PetscOptionsEnd(); 110*ec9796c4SHansol Suh PetscCheck(user->problem.bs >= 1, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " is not bigger than 1", user->problem.bs); 111*ec9796c4SHansol Suh PetscCheck((user->n % user->problem.bs) == 0, comm, PETSC_ERR_ARG_INCOMP, "Block size %" PetscInt_FMT " doest not divide problem size % " PetscInt_FMT, user->problem.bs, user->n); 112*ec9796c4SHansol Suh PetscCall(PetscLogEventRegister("Rbock_Obj", TAO_CLASSID, &user->event_f)); 113*ec9796c4SHansol Suh PetscCall(PetscLogEventRegister("Rbock_Grad", TAO_CLASSID, &user->event_g)); 114*ec9796c4SHansol Suh PetscCall(PetscLogEventRegister("Rbock_ObjGrad", TAO_CLASSID, &user->event_fg)); 115*ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetCurrentContext(&dctx)); 116*ec9796c4SHansol Suh PetscCall(PetscDeviceContextSetUp(dctx)); 117*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 118*ec9796c4SHansol Suh } 119*ec9796c4SHansol Suh 120*ec9796c4SHansol Suh static PetscErrorCode AppCtxDestroy(AppCtx *ctx) 121*ec9796c4SHansol Suh { 122*ec9796c4SHansol Suh AppCtx user; 123*ec9796c4SHansol Suh 124*ec9796c4SHansol Suh PetscFunctionBegin; 125*ec9796c4SHansol Suh user = *ctx; 126*ec9796c4SHansol Suh *ctx = NULL; 127*ec9796c4SHansol Suh PetscCall(VecDestroy(&user->Hvalues)); 128*ec9796c4SHansol Suh PetscCall(VecDestroy(&user->gvalues)); 129*ec9796c4SHansol Suh PetscCall(VecDestroy(&user->fvector)); 130*ec9796c4SHansol Suh PetscCall(VecDestroy(&user->off_process_values)); 131*ec9796c4SHansol Suh PetscCall(PetscSFDestroy(&user->off_process_scatter)); 132*ec9796c4SHansol Suh PetscCall(PetscSFDestroy(&user->gscatter)); 133*ec9796c4SHansol Suh PetscCall(PetscFree(user)); 134*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 135*ec9796c4SHansol Suh } 136*ec9796c4SHansol Suh 137*ec9796c4SHansol Suh static PetscErrorCode CreateHessian(AppCtx user, Mat *Hessian) 138*ec9796c4SHansol Suh { 139*ec9796c4SHansol Suh Mat H; 140*ec9796c4SHansol Suh PetscLayout layout; 141*ec9796c4SHansol Suh PetscInt i_start, i_end, n_local_comp, nnz_local; 142*ec9796c4SHansol Suh PetscInt c_start, c_end; 143*ec9796c4SHansol Suh PetscInt *coo_i; 144*ec9796c4SHansol Suh PetscInt *coo_j; 145*ec9796c4SHansol Suh PetscInt bs = user->problem.bs; 146*ec9796c4SHansol Suh VecType vec_type; 147*ec9796c4SHansol Suh 148*ec9796c4SHansol Suh PetscFunctionBegin; 149*ec9796c4SHansol Suh /* Partition the optimization variables and the computations. 150*ec9796c4SHansol Suh There are (bs - 1) contributions to the objective function for every (bs) 151*ec9796c4SHansol Suh degrees of freedom. */ 152*ec9796c4SHansol Suh PetscCall(PetscLayoutCreateFromSizes(user->comm, PETSC_DECIDE, user->n, 1, &layout)); 153*ec9796c4SHansol Suh PetscCall(PetscLayoutSetUp(layout)); 154*ec9796c4SHansol Suh PetscCall(PetscLayoutGetRange(layout, &i_start, &i_end)); 155*ec9796c4SHansol Suh user->problem.i_start = i_start; 156*ec9796c4SHansol Suh user->problem.i_end = i_end; 157*ec9796c4SHansol Suh user->n_local = i_end - i_start; 158*ec9796c4SHansol Suh user->problem.c_start = c_start = (i_start / bs) * (bs - 1) + (i_start % bs); 159*ec9796c4SHansol Suh user->problem.c_end = c_end = (i_end / bs) * (bs - 1) + (i_end % bs); 160*ec9796c4SHansol Suh user->n_local_comp = n_local_comp = c_end - c_start; 161*ec9796c4SHansol Suh 162*ec9796c4SHansol Suh PetscCall(MatCreate(user->comm, Hessian)); 163*ec9796c4SHansol Suh H = *Hessian; 164*ec9796c4SHansol Suh PetscCall(MatSetLayouts(H, layout, layout)); 165*ec9796c4SHansol Suh PetscCall(PetscLayoutDestroy(&layout)); 166*ec9796c4SHansol Suh PetscCall(MatSetType(H, MATAIJ)); 167*ec9796c4SHansol Suh PetscCall(MatSetOption(H, MAT_HERMITIAN, PETSC_TRUE)); 168*ec9796c4SHansol Suh PetscCall(MatSetOption(H, MAT_SYMMETRIC, PETSC_TRUE)); 169*ec9796c4SHansol Suh PetscCall(MatSetOption(H, MAT_SYMMETRY_ETERNAL, PETSC_TRUE)); 170*ec9796c4SHansol Suh PetscCall(MatSetOption(H, MAT_STRUCTURALLY_SYMMETRIC, PETSC_TRUE)); 171*ec9796c4SHansol Suh PetscCall(MatSetOption(H, MAT_STRUCTURAL_SYMMETRY_ETERNAL, PETSC_TRUE)); 172*ec9796c4SHansol Suh PetscCall(MatSetFromOptions(H)); /* set from options so that we can change the underlying matrix type */ 173*ec9796c4SHansol Suh 174*ec9796c4SHansol Suh nnz_local = n_local_comp * 4; 175*ec9796c4SHansol Suh PetscCall(PetscMalloc2(nnz_local, &coo_i, nnz_local, &coo_j)); 176*ec9796c4SHansol Suh /* Instead of having one computation thread per row of the matrix, 177*ec9796c4SHansol Suh this example uses one thread per contribution to the objective 178*ec9796c4SHansol Suh function. Each contribution to the objective function relates 179*ec9796c4SHansol Suh two adjacent degrees of freedom, so each contribution to 180*ec9796c4SHansol Suh the objective function adds a 2x2 block into the matrix. 181*ec9796c4SHansol Suh We describe these 2x2 blocks in COO format. */ 182*ec9796c4SHansol Suh for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 4) { 183*ec9796c4SHansol Suh PetscInt i = (c / (bs - 1)) * bs + c % (bs - 1); 184*ec9796c4SHansol Suh 185*ec9796c4SHansol Suh coo_i[k + 0] = i; 186*ec9796c4SHansol Suh coo_i[k + 1] = i; 187*ec9796c4SHansol Suh coo_i[k + 2] = i + 1; 188*ec9796c4SHansol Suh coo_i[k + 3] = i + 1; 189*ec9796c4SHansol Suh 190*ec9796c4SHansol Suh coo_j[k + 0] = i; 191*ec9796c4SHansol Suh coo_j[k + 1] = i + 1; 192*ec9796c4SHansol Suh coo_j[k + 2] = i; 193*ec9796c4SHansol Suh coo_j[k + 3] = i + 1; 194*ec9796c4SHansol Suh } 195*ec9796c4SHansol Suh PetscCall(MatSetPreallocationCOO(H, nnz_local, coo_i, coo_j)); 196*ec9796c4SHansol Suh PetscCall(PetscFree2(coo_i, coo_j)); 197*ec9796c4SHansol Suh 198*ec9796c4SHansol Suh PetscCall(MatGetVecType(H, &vec_type)); 199*ec9796c4SHansol Suh PetscCall(VecCreate(user->comm, &user->Hvalues)); 200*ec9796c4SHansol Suh PetscCall(VecSetSizes(user->Hvalues, nnz_local, PETSC_DETERMINE)); 201*ec9796c4SHansol Suh PetscCall(VecSetType(user->Hvalues, vec_type)); 202*ec9796c4SHansol Suh 203*ec9796c4SHansol Suh // vector to collect contributions to the objective 204*ec9796c4SHansol Suh PetscCall(VecCreate(user->comm, &user->fvector)); 205*ec9796c4SHansol Suh PetscCall(VecSetSizes(user->fvector, user->n_local_comp, PETSC_DETERMINE)); 206*ec9796c4SHansol Suh PetscCall(VecSetType(user->fvector, vec_type)); 207*ec9796c4SHansol Suh 208*ec9796c4SHansol Suh { /* If we are using a device (such as a GPU), run some computations that will 209*ec9796c4SHansol Suh warm up its linear algebra runtime before the problem we actually want 210*ec9796c4SHansol Suh to profile */ 211*ec9796c4SHansol Suh 212*ec9796c4SHansol Suh PetscMemType memtype; 213*ec9796c4SHansol Suh const PetscScalar *a; 214*ec9796c4SHansol Suh 215*ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(user->fvector, &a, &memtype)); 216*ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(user->fvector, &a)); 217*ec9796c4SHansol Suh 218*ec9796c4SHansol Suh if (memtype == PETSC_MEMTYPE_DEVICE) { 219*ec9796c4SHansol Suh PetscLogStage warmup; 220*ec9796c4SHansol Suh Mat A, AtA; 221*ec9796c4SHansol Suh Vec x, b; 222*ec9796c4SHansol Suh PetscInt warmup_size = 1000; 223*ec9796c4SHansol Suh PetscDeviceContext dctx; 224*ec9796c4SHansol Suh 225*ec9796c4SHansol Suh PetscCall(PetscLogStageRegister("Device Warmup", &warmup)); 226*ec9796c4SHansol Suh PetscCall(PetscLogStageSetActive(warmup, PETSC_FALSE)); 227*ec9796c4SHansol Suh 228*ec9796c4SHansol Suh PetscCall(PetscLogStagePush(warmup)); 229*ec9796c4SHansol Suh PetscCall(MatCreateDenseFromVecType(PETSC_COMM_SELF, vec_type, warmup_size, warmup_size, warmup_size, warmup_size, PETSC_DEFAULT, NULL, &A)); 230*ec9796c4SHansol Suh PetscCall(MatSetRandom(A, NULL)); 231*ec9796c4SHansol Suh PetscCall(MatCreateVecs(A, &x, &b)); 232*ec9796c4SHansol Suh PetscCall(VecSetRandom(x, NULL)); 233*ec9796c4SHansol Suh 234*ec9796c4SHansol Suh PetscCall(MatMult(A, x, b)); 235*ec9796c4SHansol Suh PetscCall(MatTransposeMatMult(A, A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &AtA)); 236*ec9796c4SHansol Suh PetscCall(MatShift(AtA, (PetscScalar)warmup_size)); 237*ec9796c4SHansol Suh PetscCall(MatSetOption(AtA, MAT_SPD, PETSC_TRUE)); 238*ec9796c4SHansol Suh PetscCall(MatCholeskyFactor(AtA, NULL, NULL)); 239*ec9796c4SHansol Suh PetscCall(MatDestroy(&AtA)); 240*ec9796c4SHansol Suh PetscCall(VecDestroy(&b)); 241*ec9796c4SHansol Suh PetscCall(VecDestroy(&x)); 242*ec9796c4SHansol Suh PetscCall(MatDestroy(&A)); 243*ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetCurrentContext(&dctx)); 244*ec9796c4SHansol Suh PetscCall(PetscDeviceContextSynchronize(dctx)); 245*ec9796c4SHansol Suh PetscCall(PetscLogStagePop()); 246*ec9796c4SHansol Suh } 247*ec9796c4SHansol Suh } 248*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 249*ec9796c4SHansol Suh } 250*ec9796c4SHansol Suh 251*ec9796c4SHansol Suh static PetscErrorCode CreateVectors(AppCtx user, Mat H, Vec *solution, Vec *gradient) 252*ec9796c4SHansol Suh { 253*ec9796c4SHansol Suh VecType vec_type; 254*ec9796c4SHansol Suh PetscInt n_coo, *coo_i, i_start, i_end; 255*ec9796c4SHansol Suh Vec x; 256*ec9796c4SHansol Suh PetscInt n_recv; 257*ec9796c4SHansol Suh PetscSFNode recv; 258*ec9796c4SHansol Suh PetscLayout layout; 259*ec9796c4SHansol Suh PetscInt c_start = user->problem.c_start, c_end = user->problem.c_end, bs = user->problem.bs; 260*ec9796c4SHansol Suh 261*ec9796c4SHansol Suh PetscFunctionBegin; 262*ec9796c4SHansol Suh PetscCall(MatCreateVecs(H, solution, gradient)); 263*ec9796c4SHansol Suh x = *solution; 264*ec9796c4SHansol Suh PetscCall(VecGetOwnershipRange(x, &i_start, &i_end)); 265*ec9796c4SHansol Suh PetscCall(VecGetType(x, &vec_type)); 266*ec9796c4SHansol Suh // create scatter for communicating values 267*ec9796c4SHansol Suh PetscCall(VecGetLayout(x, &layout)); 268*ec9796c4SHansol Suh n_recv = 0; 269*ec9796c4SHansol Suh if (user->n_local_comp && i_end < user->n) { 270*ec9796c4SHansol Suh PetscMPIInt rank; 271*ec9796c4SHansol Suh PetscInt index; 272*ec9796c4SHansol Suh 273*ec9796c4SHansol Suh n_recv = 1; 274*ec9796c4SHansol Suh PetscCall(PetscLayoutFindOwnerIndex(layout, i_end, &rank, &index)); 275*ec9796c4SHansol Suh recv.rank = rank; 276*ec9796c4SHansol Suh recv.index = index; 277*ec9796c4SHansol Suh } 278*ec9796c4SHansol Suh PetscCall(PetscSFCreate(user->comm, &user->off_process_scatter)); 279*ec9796c4SHansol Suh PetscCall(PetscSFSetGraph(user->off_process_scatter, user->n_local, n_recv, NULL, PETSC_USE_POINTER, &recv, PETSC_COPY_VALUES)); 280*ec9796c4SHansol Suh PetscCall(VecCreate(user->comm, &user->off_process_values)); 281*ec9796c4SHansol Suh PetscCall(VecSetSizes(user->off_process_values, 1, PETSC_DETERMINE)); 282*ec9796c4SHansol Suh PetscCall(VecSetType(user->off_process_values, vec_type)); 283*ec9796c4SHansol Suh PetscCall(VecZeroEntries(user->off_process_values)); 284*ec9796c4SHansol Suh 285*ec9796c4SHansol Suh // create COO data for writing the gradient 286*ec9796c4SHansol Suh n_coo = user->n_local_comp * 2; 287*ec9796c4SHansol Suh PetscCall(PetscMalloc1(n_coo, &coo_i)); 288*ec9796c4SHansol Suh for (PetscInt c = c_start, k = 0; c < c_end; c++, k += 2) { 289*ec9796c4SHansol Suh PetscInt i = (c / (bs - 1)) * bs + (c % (bs - 1)); 290*ec9796c4SHansol Suh 291*ec9796c4SHansol Suh coo_i[k + 0] = i; 292*ec9796c4SHansol Suh coo_i[k + 1] = i + 1; 293*ec9796c4SHansol Suh } 294*ec9796c4SHansol Suh PetscCall(PetscSFCreate(user->comm, &user->gscatter)); 295*ec9796c4SHansol Suh PetscCall(PetscSFSetGraphLayout(user->gscatter, layout, n_coo, NULL, PETSC_USE_POINTER, coo_i)); 296*ec9796c4SHansol Suh PetscCall(PetscSFSetUp(user->gscatter)); 297*ec9796c4SHansol Suh PetscCall(PetscFree(coo_i)); 298*ec9796c4SHansol Suh PetscCall(VecCreate(user->comm, &user->gvalues)); 299*ec9796c4SHansol Suh PetscCall(VecSetSizes(user->gvalues, n_coo, PETSC_DETERMINE)); 300*ec9796c4SHansol Suh PetscCall(VecSetType(user->gvalues, vec_type)); 301*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 302*ec9796c4SHansol Suh } 303*ec9796c4SHansol Suh 304*ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC) 305*ec9796c4SHansol Suh 306*ec9796c4SHansol Suh #if PetscDefined(USING_NVCC) 307*ec9796c4SHansol Suh typedef cudaStream_t cupmStream_t; 308*ec9796c4SHansol Suh #define PetscCUPMLaunch(...) \ 309*ec9796c4SHansol Suh do { \ 310*ec9796c4SHansol Suh __VA_ARGS__; \ 311*ec9796c4SHansol Suh PetscCallCUDA(cudaGetLastError()); \ 312*ec9796c4SHansol Suh } while (0) 313*ec9796c4SHansol Suh #elif PetscDefined(USING_HCC) 314*ec9796c4SHansol Suh #define PetscCUPMLaunch(...) \ 315*ec9796c4SHansol Suh do { \ 316*ec9796c4SHansol Suh __VA_ARGS__; \ 317*ec9796c4SHansol Suh PetscCallHIP(hipGetLastError()); \ 318*ec9796c4SHansol Suh } while (0) 319*ec9796c4SHansol Suh typedef hipStream_t cupmStream_t; 320*ec9796c4SHansol Suh #endif 321*ec9796c4SHansol Suh 322*ec9796c4SHansol Suh // x: on-process optimization variables 323*ec9796c4SHansol Suh // o: buffer that contains the next optimization variable after the variables on this process 324*ec9796c4SHansol Suh template <typename T> 325*ec9796c4SHansol Suh PETSC_DEVICE_INLINE_DECL static void rosenbrock_for_loop(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], T &&func) noexcept 326*ec9796c4SHansol Suh { 327*ec9796c4SHansol Suh PetscInt idx = blockIdx.x * blockDim.x + threadIdx.x; // 1D grid 328*ec9796c4SHansol Suh PetscInt num_threads = gridDim.x * blockDim.x; 329*ec9796c4SHansol Suh 330*ec9796c4SHansol Suh for (PetscInt c = r.c_start + idx, k = idx; c < r.c_end; c += num_threads, k += num_threads) { 331*ec9796c4SHansol Suh PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1)); 332*ec9796c4SHansol Suh PetscScalar x_a = x[i - r.i_start]; 333*ec9796c4SHansol Suh PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0]; 334*ec9796c4SHansol Suh 335*ec9796c4SHansol Suh func(k, x_a, x_b); 336*ec9796c4SHansol Suh } 337*ec9796c4SHansol Suh return; 338*ec9796c4SHansol Suh } 339*ec9796c4SHansol Suh 340*ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockObjective_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[]) 341*ec9796c4SHansol Suh { 342*ec9796c4SHansol Suh rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjective(r.alpha, x_a, x_b); }); 343*ec9796c4SHansol Suh } 344*ec9796c4SHansol Suh 345*ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[]) 346*ec9796c4SHansol Suh { 347*ec9796c4SHansol Suh rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]); }); 348*ec9796c4SHansol Suh } 349*ec9796c4SHansol Suh 350*ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockObjectiveGradient_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[]) 351*ec9796c4SHansol Suh { 352*ec9796c4SHansol Suh rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { f_vec[k] = RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]); }); 353*ec9796c4SHansol Suh } 354*ec9796c4SHansol Suh 355*ec9796c4SHansol Suh PETSC_KERNEL_DECL void RosenbrockHessian_Kernel(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[]) 356*ec9796c4SHansol Suh { 357*ec9796c4SHansol Suh rosenbrock_for_loop(r, x, o, [&](PetscInt k, PetscScalar x_a, PetscScalar x_b) { RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]); }); 358*ec9796c4SHansol Suh } 359*ec9796c4SHansol Suh 360*ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjective_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[]) 361*ec9796c4SHansol Suh { 362*ec9796c4SHansol Suh PetscInt n_comp = r.c_end - r.c_start; 363*ec9796c4SHansol Suh 364*ec9796c4SHansol Suh PetscFunctionBegin; 365*ec9796c4SHansol Suh if (n_comp) PetscCUPMLaunch(RosenbrockObjective_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec)); 366*ec9796c4SHansol Suh PetscCall(PetscLogGpuFlops(RosenbrockObjectiveFlops * n_comp)); 367*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 368*ec9796c4SHansol Suh } 369*ec9796c4SHansol Suh 370*ec9796c4SHansol Suh static PetscErrorCode RosenbrockGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[]) 371*ec9796c4SHansol Suh { 372*ec9796c4SHansol Suh PetscInt n_comp = r.c_end - r.c_start; 373*ec9796c4SHansol Suh 374*ec9796c4SHansol Suh PetscFunctionBegin; 375*ec9796c4SHansol Suh if (n_comp) PetscCUPMLaunch(RosenbrockGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, g)); 376*ec9796c4SHansol Suh PetscCall(PetscLogGpuFlops(RosenbrockGradientFlops * n_comp)); 377*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 378*ec9796c4SHansol Suh } 379*ec9796c4SHansol Suh 380*ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjectiveGradient_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar f_vec[], PetscScalar g[]) 381*ec9796c4SHansol Suh { 382*ec9796c4SHansol Suh PetscInt n_comp = r.c_end - r.c_start; 383*ec9796c4SHansol Suh 384*ec9796c4SHansol Suh PetscFunctionBegin; 385*ec9796c4SHansol Suh if (n_comp) PetscCUPMLaunch(RosenbrockObjectiveGradient_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, f_vec, g)); 386*ec9796c4SHansol Suh PetscCall(PetscLogGpuFlops(RosenbrockObjectiveGradientFlops * n_comp)); 387*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 388*ec9796c4SHansol Suh } 389*ec9796c4SHansol Suh 390*ec9796c4SHansol Suh static PetscErrorCode RosenbrockHessian_Device(cupmStream_t stream, Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[]) 391*ec9796c4SHansol Suh { 392*ec9796c4SHansol Suh PetscInt n_comp = r.c_end - r.c_start; 393*ec9796c4SHansol Suh 394*ec9796c4SHansol Suh PetscFunctionBegin; 395*ec9796c4SHansol Suh if (n_comp) PetscCUPMLaunch(RosenbrockHessian_Kernel<<<(n_comp + 255) / 256, 256, 0, stream>>>(r, x, o, h)); 396*ec9796c4SHansol Suh PetscCall(PetscLogGpuFlops(RosenbrockHessianFlops * n_comp)); 397*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 398*ec9796c4SHansol Suh } 399*ec9796c4SHansol Suh #endif 400*ec9796c4SHansol Suh 401*ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjective_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f) 402*ec9796c4SHansol Suh { 403*ec9796c4SHansol Suh PetscReal _f = 0.0; 404*ec9796c4SHansol Suh 405*ec9796c4SHansol Suh PetscFunctionBegin; 406*ec9796c4SHansol Suh for (PetscInt c = r.c_start; c < r.c_end; c++) { 407*ec9796c4SHansol Suh PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1)); 408*ec9796c4SHansol Suh PetscScalar x_a = x[i - r.i_start]; 409*ec9796c4SHansol Suh PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0]; 410*ec9796c4SHansol Suh 411*ec9796c4SHansol Suh _f += RosenbrockObjective(r.alpha, x_a, x_b); 412*ec9796c4SHansol Suh } 413*ec9796c4SHansol Suh *f = _f; 414*ec9796c4SHansol Suh PetscCall(PetscLogFlops((RosenbrockObjectiveFlops + 1.0) * (r.c_end - r.c_start))); 415*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 416*ec9796c4SHansol Suh } 417*ec9796c4SHansol Suh 418*ec9796c4SHansol Suh static PetscErrorCode RosenbrockGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar g[]) 419*ec9796c4SHansol Suh { 420*ec9796c4SHansol Suh PetscFunctionBegin; 421*ec9796c4SHansol Suh for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) { 422*ec9796c4SHansol Suh PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1)); 423*ec9796c4SHansol Suh PetscScalar x_a = x[i - r.i_start]; 424*ec9796c4SHansol Suh PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0]; 425*ec9796c4SHansol Suh 426*ec9796c4SHansol Suh RosenbrockGradient(r.alpha, x_a, x_b, &g[2 * k]); 427*ec9796c4SHansol Suh } 428*ec9796c4SHansol Suh PetscCall(PetscLogFlops(RosenbrockGradientFlops * (r.c_end - r.c_start))); 429*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 430*ec9796c4SHansol Suh } 431*ec9796c4SHansol Suh 432*ec9796c4SHansol Suh static PetscErrorCode RosenbrockObjectiveGradient_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscReal *f, PetscScalar g[]) 433*ec9796c4SHansol Suh { 434*ec9796c4SHansol Suh PetscReal _f = 0.0; 435*ec9796c4SHansol Suh 436*ec9796c4SHansol Suh PetscFunctionBegin; 437*ec9796c4SHansol Suh for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) { 438*ec9796c4SHansol Suh PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1)); 439*ec9796c4SHansol Suh PetscScalar x_a = x[i - r.i_start]; 440*ec9796c4SHansol Suh PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0]; 441*ec9796c4SHansol Suh 442*ec9796c4SHansol Suh _f += RosenbrockObjectiveGradient(r.alpha, x_a, x_b, &g[2 * k]); 443*ec9796c4SHansol Suh } 444*ec9796c4SHansol Suh *f = _f; 445*ec9796c4SHansol Suh PetscCall(PetscLogFlops(RosenbrockObjectiveGradientFlops * (r.c_end - r.c_start))); 446*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 447*ec9796c4SHansol Suh } 448*ec9796c4SHansol Suh 449*ec9796c4SHansol Suh static PetscErrorCode RosenbrockHessian_Host(Rosenbrock r, const PetscScalar x[], const PetscScalar o[], PetscScalar h[]) 450*ec9796c4SHansol Suh { 451*ec9796c4SHansol Suh PetscFunctionBegin; 452*ec9796c4SHansol Suh for (PetscInt c = r.c_start, k = 0; c < r.c_end; c++, k++) { 453*ec9796c4SHansol Suh PetscInt i = (c / (r.bs - 1)) * r.bs + (c % (r.bs - 1)); 454*ec9796c4SHansol Suh PetscScalar x_a = x[i - r.i_start]; 455*ec9796c4SHansol Suh PetscScalar x_b = ((i + 1) < r.i_end) ? x[i + 1 - r.i_start] : o[0]; 456*ec9796c4SHansol Suh 457*ec9796c4SHansol Suh RosenbrockHessian(r.alpha, x_a, x_b, &h[4 * k]); 458*ec9796c4SHansol Suh } 459*ec9796c4SHansol Suh PetscCall(PetscLogFlops(RosenbrockHessianFlops * (r.c_end - r.c_start))); 460*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 461*ec9796c4SHansol Suh } 462*ec9796c4SHansol Suh 463*ec9796c4SHansol Suh /* -------------------------------------------------------------------- */ 464*ec9796c4SHansol Suh 465*ec9796c4SHansol Suh static PetscErrorCode FormObjective(Tao tao, Vec X, PetscReal *f, void *ptr) 466*ec9796c4SHansol Suh { 467*ec9796c4SHansol Suh AppCtx user = (AppCtx)ptr; 468*ec9796c4SHansol Suh PetscReal f_local = 0.0; 469*ec9796c4SHansol Suh const PetscScalar *x; 470*ec9796c4SHansol Suh const PetscScalar *o = NULL; 471*ec9796c4SHansol Suh PetscMemType memtype_x; 472*ec9796c4SHansol Suh 473*ec9796c4SHansol Suh PetscFunctionBeginUser; 474*ec9796c4SHansol Suh PetscCall(PetscLogEventBegin(user->event_f, tao, NULL, NULL, NULL)); 475*ec9796c4SHansol Suh PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 476*ec9796c4SHansol Suh PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 477*ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL)); 478*ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x)); 479*ec9796c4SHansol Suh if (memtype_x == PETSC_MEMTYPE_HOST) { 480*ec9796c4SHansol Suh PetscCall(RosenbrockObjective_Host(user->problem, x, o, &f_local)); 481*ec9796c4SHansol Suh PetscCallMPI(MPI_Allreduce(&f_local, f, 1, MPI_DOUBLE, MPI_SUM, user->comm)); 482*ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC) 483*ec9796c4SHansol Suh } else if (memtype_x == PETSC_MEMTYPE_DEVICE) { 484*ec9796c4SHansol Suh PetscScalar *_fvec; 485*ec9796c4SHansol Suh PetscScalar f_scalar; 486*ec9796c4SHansol Suh cupmStream_t *stream; 487*ec9796c4SHansol Suh PetscDeviceContext dctx; 488*ec9796c4SHansol Suh 489*ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetCurrentContext(&dctx)); 490*ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream)); 491*ec9796c4SHansol Suh PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL)); 492*ec9796c4SHansol Suh PetscCall(RosenbrockObjective_Device(*stream, user->problem, x, o, _fvec)); 493*ec9796c4SHansol Suh PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec)); 494*ec9796c4SHansol Suh PetscCall(VecSum(user->fvector, &f_scalar)); 495*ec9796c4SHansol Suh *f = PetscRealPart(f_scalar); 496*ec9796c4SHansol Suh #endif 497*ec9796c4SHansol Suh } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsuported memtype %d", (int)memtype_x); 498*ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(X, &x)); 499*ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o)); 500*ec9796c4SHansol Suh PetscCall(PetscLogEventEnd(user->event_f, tao, NULL, NULL, NULL)); 501*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 502*ec9796c4SHansol Suh } 503*ec9796c4SHansol Suh 504*ec9796c4SHansol Suh static PetscErrorCode FormGradient(Tao tao, Vec X, Vec G, void *ptr) 505*ec9796c4SHansol Suh { 506*ec9796c4SHansol Suh AppCtx user = (AppCtx)ptr; 507*ec9796c4SHansol Suh PetscScalar *g; 508*ec9796c4SHansol Suh const PetscScalar *x; 509*ec9796c4SHansol Suh const PetscScalar *o = NULL; 510*ec9796c4SHansol Suh PetscMemType memtype_x, memtype_g; 511*ec9796c4SHansol Suh 512*ec9796c4SHansol Suh PetscFunctionBeginUser; 513*ec9796c4SHansol Suh PetscCall(PetscLogEventBegin(user->event_g, tao, NULL, NULL, NULL)); 514*ec9796c4SHansol Suh PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 515*ec9796c4SHansol Suh PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 516*ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL)); 517*ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x)); 518*ec9796c4SHansol Suh PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g)); 519*ec9796c4SHansol Suh PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype"); 520*ec9796c4SHansol Suh if (memtype_x == PETSC_MEMTYPE_HOST) { 521*ec9796c4SHansol Suh PetscCall(RosenbrockGradient_Host(user->problem, x, o, g)); 522*ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC) 523*ec9796c4SHansol Suh } else if (memtype_x == PETSC_MEMTYPE_DEVICE) { 524*ec9796c4SHansol Suh cupmStream_t *stream; 525*ec9796c4SHansol Suh PetscDeviceContext dctx; 526*ec9796c4SHansol Suh 527*ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetCurrentContext(&dctx)); 528*ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream)); 529*ec9796c4SHansol Suh PetscCall(RosenbrockGradient_Device(*stream, user->problem, x, o, g)); 530*ec9796c4SHansol Suh #endif 531*ec9796c4SHansol Suh } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsuported memtype %d", (int)memtype_x); 532*ec9796c4SHansol Suh PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g)); 533*ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(X, &x)); 534*ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o)); 535*ec9796c4SHansol Suh PetscCall(VecZeroEntries(G)); 536*ec9796c4SHansol Suh PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE)); 537*ec9796c4SHansol Suh PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE)); 538*ec9796c4SHansol Suh PetscCall(PetscLogEventEnd(user->event_g, tao, NULL, NULL, NULL)); 539*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 540*ec9796c4SHansol Suh } 541*ec9796c4SHansol Suh 542*ec9796c4SHansol Suh /* 543*ec9796c4SHansol Suh FormObjectiveGradient - Evaluates the function, f(X), and gradient, G(X). 544*ec9796c4SHansol Suh 545*ec9796c4SHansol Suh Input Parameters: 546*ec9796c4SHansol Suh . tao - the Tao context 547*ec9796c4SHansol Suh . X - input vector 548*ec9796c4SHansol Suh . ptr - optional user-defined context, as set by TaoSetObjectiveGradient() 549*ec9796c4SHansol Suh 550*ec9796c4SHansol Suh Output Parameters: 551*ec9796c4SHansol Suh . G - vector containing the newly evaluated gradient 552*ec9796c4SHansol Suh . f - function value 553*ec9796c4SHansol Suh 554*ec9796c4SHansol Suh Note: 555*ec9796c4SHansol Suh Some optimization methods ask for the function and the gradient evaluation 556*ec9796c4SHansol Suh at the same time. Evaluating both at once may be more efficient that 557*ec9796c4SHansol Suh evaluating each separately. 558*ec9796c4SHansol Suh */ 559*ec9796c4SHansol Suh static PetscErrorCode FormObjectiveGradient(Tao tao, Vec X, PetscReal *f, Vec G, void *ptr) 560*ec9796c4SHansol Suh { 561*ec9796c4SHansol Suh AppCtx user = (AppCtx)ptr; 562*ec9796c4SHansol Suh PetscReal f_local = 0.0; 563*ec9796c4SHansol Suh PetscScalar *g; 564*ec9796c4SHansol Suh const PetscScalar *x; 565*ec9796c4SHansol Suh const PetscScalar *o = NULL; 566*ec9796c4SHansol Suh PetscMemType memtype_x, memtype_g; 567*ec9796c4SHansol Suh 568*ec9796c4SHansol Suh PetscFunctionBeginUser; 569*ec9796c4SHansol Suh PetscCall(PetscLogEventBegin(user->event_fg, tao, NULL, NULL, NULL)); 570*ec9796c4SHansol Suh PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 571*ec9796c4SHansol Suh PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 572*ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL)); 573*ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x)); 574*ec9796c4SHansol Suh PetscCall(VecGetArrayWriteAndMemType(user->gvalues, &g, &memtype_g)); 575*ec9796c4SHansol Suh PetscAssert(memtype_x == memtype_g, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and gradient must have save memtype"); 576*ec9796c4SHansol Suh if (memtype_x == PETSC_MEMTYPE_HOST) { 577*ec9796c4SHansol Suh PetscCall(RosenbrockObjectiveGradient_Host(user->problem, x, o, &f_local, g)); 578*ec9796c4SHansol Suh PetscCallMPI(MPI_Allreduce((void *)&f_local, (void *)f, 1, MPI_DOUBLE, MPI_SUM, PETSC_COMM_WORLD)); 579*ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC) 580*ec9796c4SHansol Suh } else if (memtype_x == PETSC_MEMTYPE_DEVICE) { 581*ec9796c4SHansol Suh PetscScalar *_fvec; 582*ec9796c4SHansol Suh PetscScalar f_scalar; 583*ec9796c4SHansol Suh cupmStream_t *stream; 584*ec9796c4SHansol Suh PetscDeviceContext dctx; 585*ec9796c4SHansol Suh 586*ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetCurrentContext(&dctx)); 587*ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream)); 588*ec9796c4SHansol Suh PetscCall(VecGetArrayWriteAndMemType(user->fvector, &_fvec, NULL)); 589*ec9796c4SHansol Suh PetscCall(RosenbrockObjectiveGradient_Device(*stream, user->problem, x, o, _fvec, g)); 590*ec9796c4SHansol Suh PetscCall(VecRestoreArrayWriteAndMemType(user->fvector, &_fvec)); 591*ec9796c4SHansol Suh PetscCall(VecSum(user->fvector, &f_scalar)); 592*ec9796c4SHansol Suh *f = PetscRealPart(f_scalar); 593*ec9796c4SHansol Suh #endif 594*ec9796c4SHansol Suh } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsuported memtype %d", (int)memtype_x); 595*ec9796c4SHansol Suh 596*ec9796c4SHansol Suh PetscCall(VecRestoreArrayWriteAndMemType(user->gvalues, &g)); 597*ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(X, &x)); 598*ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o)); 599*ec9796c4SHansol Suh PetscCall(VecZeroEntries(G)); 600*ec9796c4SHansol Suh PetscCall(VecScatterBegin(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE)); 601*ec9796c4SHansol Suh PetscCall(VecScatterEnd(user->gscatter, user->gvalues, G, ADD_VALUES, SCATTER_REVERSE)); 602*ec9796c4SHansol Suh PetscCall(PetscLogEventEnd(user->event_fg, tao, NULL, NULL, NULL)); 603*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 604*ec9796c4SHansol Suh } 605*ec9796c4SHansol Suh 606*ec9796c4SHansol Suh /* ------------------------------------------------------------------- */ 607*ec9796c4SHansol Suh /* 608*ec9796c4SHansol Suh FormHessian - Evaluates Hessian matrix. 609*ec9796c4SHansol Suh 610*ec9796c4SHansol Suh Input Parameters: 611*ec9796c4SHansol Suh . tao - the Tao context 612*ec9796c4SHansol Suh . x - input vector 613*ec9796c4SHansol Suh . ptr - optional user-defined context, as set by TaoSetHessian() 614*ec9796c4SHansol Suh 615*ec9796c4SHansol Suh Output Parameters: 616*ec9796c4SHansol Suh . H - Hessian matrix 617*ec9796c4SHansol Suh 618*ec9796c4SHansol Suh Note: Providing the Hessian may not be necessary. Only some solvers 619*ec9796c4SHansol Suh require this matrix. 620*ec9796c4SHansol Suh */ 621*ec9796c4SHansol Suh static PetscErrorCode FormHessian(Tao tao, Vec X, Mat H, Mat Hpre, void *ptr) 622*ec9796c4SHansol Suh { 623*ec9796c4SHansol Suh AppCtx user = (AppCtx)ptr; 624*ec9796c4SHansol Suh PetscScalar *h; 625*ec9796c4SHansol Suh const PetscScalar *x; 626*ec9796c4SHansol Suh const PetscScalar *o = NULL; 627*ec9796c4SHansol Suh PetscMemType memtype_x, memtype_h; 628*ec9796c4SHansol Suh 629*ec9796c4SHansol Suh PetscFunctionBeginUser; 630*ec9796c4SHansol Suh PetscCall(VecScatterBegin(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 631*ec9796c4SHansol Suh PetscCall(VecScatterEnd(user->off_process_scatter, X, user->off_process_values, INSERT_VALUES, SCATTER_FORWARD)); 632*ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(user->off_process_values, &o, NULL)); 633*ec9796c4SHansol Suh PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype_x)); 634*ec9796c4SHansol Suh PetscCall(VecGetArrayWriteAndMemType(user->Hvalues, &h, &memtype_h)); 635*ec9796c4SHansol Suh PetscAssert(memtype_x == memtype_h, user->comm, PETSC_ERR_ARG_INCOMP, "solution vector and hessian must have save memtype"); 636*ec9796c4SHansol Suh if (memtype_x == PETSC_MEMTYPE_HOST) { 637*ec9796c4SHansol Suh PetscCall(RosenbrockHessian_Host(user->problem, x, o, h)); 638*ec9796c4SHansol Suh #if PetscDefined(USING_CUPMCC) 639*ec9796c4SHansol Suh } else if (memtype_x == PETSC_MEMTYPE_DEVICE) { 640*ec9796c4SHansol Suh cupmStream_t *stream; 641*ec9796c4SHansol Suh PetscDeviceContext dctx; 642*ec9796c4SHansol Suh 643*ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetCurrentContext(&dctx)); 644*ec9796c4SHansol Suh PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream)); 645*ec9796c4SHansol Suh PetscCall(RosenbrockHessian_Device(*stream, user->problem, x, o, h)); 646*ec9796c4SHansol Suh #endif 647*ec9796c4SHansol Suh } else SETERRQ(user->comm, PETSC_ERR_SUP, "Unsuported memtype %d", (int)memtype_x); 648*ec9796c4SHansol Suh 649*ec9796c4SHansol Suh PetscCall(MatSetValuesCOO(H, h, INSERT_VALUES)); 650*ec9796c4SHansol Suh PetscCall(VecRestoreArrayWriteAndMemType(user->Hvalues, &h)); 651*ec9796c4SHansol Suh 652*ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(X, &x)); 653*ec9796c4SHansol Suh PetscCall(VecRestoreArrayReadAndMemType(user->off_process_values, &o)); 654*ec9796c4SHansol Suh 655*ec9796c4SHansol Suh if (Hpre != H) PetscCall(MatCopy(H, Hpre, SAME_NONZERO_PATTERN)); 656*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 657*ec9796c4SHansol Suh } 658*ec9796c4SHansol Suh 659*ec9796c4SHansol Suh static PetscErrorCode TestLMVM(Tao tao) 660*ec9796c4SHansol Suh { 661*ec9796c4SHansol Suh KSP ksp; 662*ec9796c4SHansol Suh PC pc; 663*ec9796c4SHansol Suh PetscBool is_lmvm; 664*ec9796c4SHansol Suh 665*ec9796c4SHansol Suh PetscFunctionBegin; 666*ec9796c4SHansol Suh PetscCall(TaoGetKSP(tao, &ksp)); 667*ec9796c4SHansol Suh if (!ksp) PetscFunctionReturn(PETSC_SUCCESS); 668*ec9796c4SHansol Suh PetscCall(KSPGetPC(ksp, &pc)); 669*ec9796c4SHansol Suh PetscCall(PetscObjectTypeCompare((PetscObject)pc, PCLMVM, &is_lmvm)); 670*ec9796c4SHansol Suh if (is_lmvm) { 671*ec9796c4SHansol Suh Mat M; 672*ec9796c4SHansol Suh Vec in, out, out2; 673*ec9796c4SHansol Suh PetscReal mult_solve_dist; 674*ec9796c4SHansol Suh Vec x; 675*ec9796c4SHansol Suh 676*ec9796c4SHansol Suh PetscCall(PCLMVMGetMatLMVM(pc, &M)); 677*ec9796c4SHansol Suh PetscCall(TaoGetSolution(tao, &x)); 678*ec9796c4SHansol Suh PetscCall(VecDuplicate(x, &in)); 679*ec9796c4SHansol Suh PetscCall(VecDuplicate(x, &out)); 680*ec9796c4SHansol Suh PetscCall(VecDuplicate(x, &out2)); 681*ec9796c4SHansol Suh PetscCall(VecSetRandom(in, NULL)); 682*ec9796c4SHansol Suh PetscCall(MatMult(M, in, out)); 683*ec9796c4SHansol Suh PetscCall(MatSolve(M, out, out2)); 684*ec9796c4SHansol Suh 685*ec9796c4SHansol Suh PetscCall(VecAXPY(out2, -1.0, in)); 686*ec9796c4SHansol Suh PetscCall(VecNorm(out2, NORM_2, &mult_solve_dist)); 687*ec9796c4SHansol Suh if (mult_solve_dist < 1.e-11) { 688*ec9796c4SHansol Suh PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-11\n")); 689*ec9796c4SHansol Suh } else if (mult_solve_dist < 1.e-6) { 690*ec9796c4SHansol Suh PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve: < 1.e-6\n")); 691*ec9796c4SHansol Suh } else { 692*ec9796c4SHansol Suh PetscCall(PetscPrintf(PetscObjectComm((PetscObject)tao), "Inverse error of LMVM MatMult and MatSolve is not small: %e\n", (double)mult_solve_dist)); 693*ec9796c4SHansol Suh } 694*ec9796c4SHansol Suh PetscCall(VecDestroy(&in)); 695*ec9796c4SHansol Suh PetscCall(VecDestroy(&out)); 696*ec9796c4SHansol Suh PetscCall(VecDestroy(&out2)); 697*ec9796c4SHansol Suh } 698*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 699*ec9796c4SHansol Suh } 700*ec9796c4SHansol Suh 701*ec9796c4SHansol Suh static PetscErrorCode RosenbrockMain(void) 702*ec9796c4SHansol Suh { 703*ec9796c4SHansol Suh Vec x; /* solution vector */ 704*ec9796c4SHansol Suh Vec g; /* gradient vector */ 705*ec9796c4SHansol Suh Mat H; /* Hessian matrix */ 706*ec9796c4SHansol Suh Tao tao; /* Tao solver context */ 707*ec9796c4SHansol Suh AppCtx user; /* user-defined application context */ 708*ec9796c4SHansol Suh PetscLogStage solve; 709*ec9796c4SHansol Suh 710*ec9796c4SHansol Suh /* Initialize TAO and PETSc */ 711*ec9796c4SHansol Suh PetscFunctionBegin; 712*ec9796c4SHansol Suh PetscCall(PetscLogStageRegister("Rosenbrock solve", &solve)); 713*ec9796c4SHansol Suh 714*ec9796c4SHansol Suh PetscCall(AppCtxCreate(PETSC_COMM_WORLD, &user)); 715*ec9796c4SHansol Suh PetscCall(CreateHessian(user, &H)); 716*ec9796c4SHansol Suh PetscCall(CreateVectors(user, H, &x, &g)); 717*ec9796c4SHansol Suh 718*ec9796c4SHansol Suh /* The TAO code begins here */ 719*ec9796c4SHansol Suh 720*ec9796c4SHansol Suh PetscCall(TaoCreate(user->comm, &tao)); 721*ec9796c4SHansol Suh PetscCall(VecZeroEntries(x)); 722*ec9796c4SHansol Suh PetscCall(TaoSetSolution(tao, x)); 723*ec9796c4SHansol Suh 724*ec9796c4SHansol Suh /* Set routines for function, gradient, hessian evaluation */ 725*ec9796c4SHansol Suh PetscCall(TaoSetObjective(tao, FormObjective, user)); 726*ec9796c4SHansol Suh PetscCall(TaoSetObjectiveAndGradient(tao, g, FormObjectiveGradient, user)); 727*ec9796c4SHansol Suh PetscCall(TaoSetGradient(tao, g, FormGradient, user)); 728*ec9796c4SHansol Suh PetscCall(TaoSetHessian(tao, H, H, FormHessian, user)); 729*ec9796c4SHansol Suh 730*ec9796c4SHansol Suh PetscCall(TaoSetFromOptions(tao)); 731*ec9796c4SHansol Suh 732*ec9796c4SHansol Suh /* SOLVE THE APPLICATION */ 733*ec9796c4SHansol Suh PetscCall(PetscLogStagePush(solve)); 734*ec9796c4SHansol Suh PetscCall(TaoSolve(tao)); 735*ec9796c4SHansol Suh PetscCall(PetscLogStagePop()); 736*ec9796c4SHansol Suh 737*ec9796c4SHansol Suh if (user->test_lmvm) PetscCall(TestLMVM(tao)); 738*ec9796c4SHansol Suh 739*ec9796c4SHansol Suh PetscCall(TaoDestroy(&tao)); 740*ec9796c4SHansol Suh PetscCall(VecDestroy(&g)); 741*ec9796c4SHansol Suh PetscCall(VecDestroy(&x)); 742*ec9796c4SHansol Suh PetscCall(MatDestroy(&H)); 743*ec9796c4SHansol Suh PetscCall(AppCtxDestroy(&user)); 744*ec9796c4SHansol Suh PetscFunctionReturn(PETSC_SUCCESS); 745*ec9796c4SHansol Suh } 746