1 2 #include <petsc/private/matimpl.h> /*I "petscmat.h" I*/ 3 4 typedef struct { 5 IS isrow,iscol; /* rows and columns in submatrix, only used to check consistency */ 6 Vec left,right; /* optional scaling */ 7 Vec olwork,orwork; /* work vectors outside the scatters, only touched by PreScale and only created if needed*/ 8 Vec lwork,rwork; /* work vectors inside the scatters */ 9 VecScatter lrestrict,rprolong; 10 Mat A; 11 PetscScalar scale; 12 } Mat_SubVirtual; 13 14 static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx) 15 { 16 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 17 PetscErrorCode ierr; 18 19 PetscFunctionBegin; 20 if (!Na->left) { 21 *xx = x; 22 } else { 23 if (!Na->olwork) { 24 ierr = VecDuplicate(Na->left,&Na->olwork);CHKERRQ(ierr); 25 } 26 ierr = VecPointwiseMult(Na->olwork,x,Na->left);CHKERRQ(ierr); 27 *xx = Na->olwork; 28 } 29 PetscFunctionReturn(0); 30 } 31 32 static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx) 33 { 34 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 35 PetscErrorCode ierr; 36 37 PetscFunctionBegin; 38 if (!Na->right) { 39 *xx = x; 40 } else { 41 if (!Na->orwork) { 42 ierr = VecDuplicate(Na->right,&Na->orwork);CHKERRQ(ierr); 43 } 44 ierr = VecPointwiseMult(Na->orwork,x,Na->right);CHKERRQ(ierr); 45 *xx = Na->orwork; 46 } 47 PetscFunctionReturn(0); 48 } 49 50 static PetscErrorCode PostScaleLeft(Mat N,Vec x) 51 { 52 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 53 PetscErrorCode ierr; 54 55 PetscFunctionBegin; 56 if (Na->left) { 57 ierr = VecPointwiseMult(x,x,Na->left);CHKERRQ(ierr); 58 } 59 PetscFunctionReturn(0); 60 } 61 62 static PetscErrorCode PostScaleRight(Mat N,Vec x) 63 { 64 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 65 PetscErrorCode ierr; 66 67 PetscFunctionBegin; 68 if (Na->right) { 69 ierr = VecPointwiseMult(x,x,Na->right);CHKERRQ(ierr); 70 } 71 PetscFunctionReturn(0); 72 } 73 74 static PetscErrorCode MatShift_SubMatrix(Mat N,PetscScalar shift) 75 { 76 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 77 PetscErrorCode ierr; 78 79 PetscFunctionBegin; 80 ierr = MatShift(Na->A,shift);CHKERRQ(ierr); 81 PetscFunctionReturn(0); 82 } 83 84 static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar scale) 85 { 86 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 87 88 PetscFunctionBegin; 89 Na->scale *= scale; 90 PetscFunctionReturn(0); 91 } 92 93 static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right) 94 { 95 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 96 PetscErrorCode ierr; 97 98 PetscFunctionBegin; 99 if (left) { 100 if (!Na->left) { 101 ierr = VecDuplicate(left,&Na->left);CHKERRQ(ierr); 102 ierr = VecCopy(left,Na->left);CHKERRQ(ierr); 103 } else { 104 ierr = VecPointwiseMult(Na->left,left,Na->left);CHKERRQ(ierr); 105 } 106 } 107 if (right) { 108 if (!Na->right) { 109 ierr = VecDuplicate(right,&Na->right);CHKERRQ(ierr); 110 ierr = VecCopy(right,Na->right);CHKERRQ(ierr); 111 } else { 112 ierr = VecPointwiseMult(Na->right,right,Na->right);CHKERRQ(ierr); 113 } 114 } 115 PetscFunctionReturn(0); 116 } 117 118 static PetscErrorCode MatSolve_SubMatrix(Mat N,Vec x,Vec y) 119 { 120 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 121 Vec lwork, rwork, xx = 0; 122 PetscErrorCode ierr; 123 124 PetscFunctionBegin; 125 ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr); 126 ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr); 127 ierr = VecGetSubVector(Na->lwork, Na->iscol, &lwork);CHKERRQ(ierr); 128 ierr = VecCopy(x, lwork);CHKERRQ(ierr); 129 ierr = VecRestoreSubVector(Na->lwork, Na->iscol, &lwork);CHKERRQ(ierr); 130 ierr = MatSolve(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr); 131 ierr = VecGetSubVector(Na->rwork, Na->isrow, &rwork);CHKERRQ(ierr); 132 ierr = VecCopy(rwork, y);CHKERRQ(ierr); 133 ierr = VecRestoreSubVector(Na->rwork, Na->isrow, &rwork);CHKERRQ(ierr); 134 ierr = PostScaleRight(N,y);CHKERRQ(ierr); 135 ierr = VecScale(y,Na->scale);CHKERRQ(ierr); 136 PetscFunctionReturn(0); 137 } 138 139 static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y) 140 { 141 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 142 Vec xx = 0; 143 PetscErrorCode ierr; 144 145 PetscFunctionBegin; 146 ierr = PreScaleRight(N,x,&xx);CHKERRQ(ierr); 147 ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr); 148 ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr); 149 ierr = VecScatterEnd (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr); 150 ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr); 151 ierr = VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr); 152 ierr = VecScatterEnd (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr); 153 ierr = PostScaleLeft(N,y);CHKERRQ(ierr); 154 ierr = VecScale(y,Na->scale);CHKERRQ(ierr); 155 PetscFunctionReturn(0); 156 } 157 158 static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3) 159 { 160 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 161 Vec xx = 0; 162 PetscErrorCode ierr; 163 164 PetscFunctionBegin; 165 ierr = PreScaleRight(N,v1,&xx);CHKERRQ(ierr); 166 ierr = VecZeroEntries(Na->rwork);CHKERRQ(ierr); 167 ierr = VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr); 168 ierr = VecScatterEnd (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr); 169 ierr = MatMult(Na->A,Na->rwork,Na->lwork);CHKERRQ(ierr); 170 if (v2 == v3) { 171 if (Na->scale == (PetscScalar)1.0 && !Na->left) { 172 ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr); 173 ierr = VecScatterEnd (Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);CHKERRQ(ierr); 174 } else { 175 if (!Na->olwork) {ierr = VecDuplicate(v3,&Na->olwork);CHKERRQ(ierr);} 176 ierr = VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr); 177 ierr = VecScatterEnd (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr); 178 ierr = PostScaleLeft(N,Na->olwork);CHKERRQ(ierr); 179 ierr = VecAXPY(v3,Na->scale,Na->olwork);CHKERRQ(ierr); 180 } 181 } else { 182 ierr = VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr); 183 ierr = VecScatterEnd (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr); 184 ierr = PostScaleLeft(N,v3);CHKERRQ(ierr); 185 ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr); 186 } 187 PetscFunctionReturn(0); 188 } 189 190 static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y) 191 { 192 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 193 Vec xx = 0; 194 PetscErrorCode ierr; 195 196 PetscFunctionBegin; 197 ierr = PreScaleLeft(N,x,&xx);CHKERRQ(ierr); 198 ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr); 199 ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr); 200 ierr = VecScatterEnd (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr); 201 ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr); 202 ierr = VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr); 203 ierr = VecScatterEnd (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr); 204 ierr = PostScaleRight(N,y);CHKERRQ(ierr); 205 ierr = VecScale(y,Na->scale);CHKERRQ(ierr); 206 PetscFunctionReturn(0); 207 } 208 209 static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3) 210 { 211 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 212 Vec xx = 0; 213 PetscErrorCode ierr; 214 215 PetscFunctionBegin; 216 ierr = PreScaleLeft(N,v1,&xx);CHKERRQ(ierr); 217 ierr = VecZeroEntries(Na->lwork);CHKERRQ(ierr); 218 ierr = VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr); 219 ierr = VecScatterEnd (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr); 220 ierr = MatMultTranspose(Na->A,Na->lwork,Na->rwork);CHKERRQ(ierr); 221 if (v2 == v3) { 222 if (Na->scale == (PetscScalar)1.0 && !Na->right) { 223 ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr); 224 ierr = VecScatterEnd (Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);CHKERRQ(ierr); 225 } else { 226 if (!Na->orwork) {ierr = VecDuplicate(v3,&Na->orwork);CHKERRQ(ierr);} 227 ierr = VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr); 228 ierr = VecScatterEnd (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr); 229 ierr = PostScaleRight(N,Na->orwork);CHKERRQ(ierr); 230 ierr = VecAXPY(v3,Na->scale,Na->orwork);CHKERRQ(ierr); 231 } 232 } else { 233 ierr = VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr); 234 ierr = VecScatterEnd (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);CHKERRQ(ierr); 235 ierr = PostScaleRight(N,v3);CHKERRQ(ierr); 236 ierr = VecAYPX(v3,Na->scale,v2);CHKERRQ(ierr); 237 } 238 PetscFunctionReturn(0); 239 } 240 241 static PetscErrorCode MatDestroy_SubMatrix(Mat N) 242 { 243 Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data; 244 PetscErrorCode ierr; 245 246 PetscFunctionBegin; 247 ierr = ISDestroy(&Na->isrow);CHKERRQ(ierr); 248 ierr = ISDestroy(&Na->iscol);CHKERRQ(ierr); 249 ierr = VecDestroy(&Na->left);CHKERRQ(ierr); 250 ierr = VecDestroy(&Na->right);CHKERRQ(ierr); 251 ierr = VecDestroy(&Na->olwork);CHKERRQ(ierr); 252 ierr = VecDestroy(&Na->orwork);CHKERRQ(ierr); 253 ierr = VecDestroy(&Na->lwork);CHKERRQ(ierr); 254 ierr = VecDestroy(&Na->rwork);CHKERRQ(ierr); 255 ierr = VecScatterDestroy(&Na->lrestrict);CHKERRQ(ierr); 256 ierr = VecScatterDestroy(&Na->rprolong);CHKERRQ(ierr); 257 ierr = MatDestroy(&Na->A);CHKERRQ(ierr); 258 ierr = PetscFree(N->data);CHKERRQ(ierr); 259 PetscFunctionReturn(0); 260 } 261 262 /*@ 263 MatCreateSubMatrixVirtual - Creates a virtual matrix that acts as a submatrix 264 265 Collective on Mat 266 267 Input Parameters: 268 + A - matrix that we will extract a submatrix of 269 . isrow - rows to be present in the submatrix 270 - iscol - columns to be present in the submatrix 271 272 Output Parameters: 273 . newmat - new matrix 274 275 Level: developer 276 277 Notes: 278 Most will use MatCreateSubMatrix which provides a more efficient representation if it is available. 279 280 .seealso: MatCreateSubMatrix(), MatSubMatrixVirtualUpdate() 281 @*/ 282 PetscErrorCode MatCreateSubMatrixVirtual(Mat A,IS isrow,IS iscol,Mat *newmat) 283 { 284 Vec left,right; 285 PetscInt m,n; 286 Mat N; 287 Mat_SubVirtual *Na; 288 PetscErrorCode ierr; 289 290 PetscFunctionBegin; 291 PetscValidHeaderSpecific(A,MAT_CLASSID,1); 292 PetscValidHeaderSpecific(isrow,IS_CLASSID,2); 293 PetscValidHeaderSpecific(iscol,IS_CLASSID,3); 294 PetscValidPointer(newmat,4); 295 *newmat = 0; 296 297 ierr = MatCreate(PetscObjectComm((PetscObject)A),&N);CHKERRQ(ierr); 298 ierr = ISGetLocalSize(isrow,&m);CHKERRQ(ierr); 299 ierr = ISGetLocalSize(iscol,&n);CHKERRQ(ierr); 300 ierr = MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);CHKERRQ(ierr); 301 ierr = PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);CHKERRQ(ierr); 302 303 ierr = PetscNewLog(N,&Na);CHKERRQ(ierr); 304 N->data = (void*)Na; 305 ierr = PetscObjectReference((PetscObject)A);CHKERRQ(ierr); 306 ierr = PetscObjectReference((PetscObject)isrow);CHKERRQ(ierr); 307 ierr = PetscObjectReference((PetscObject)iscol);CHKERRQ(ierr); 308 Na->A = A; 309 Na->isrow = isrow; 310 Na->iscol = iscol; 311 Na->scale = 1.0; 312 313 N->ops->destroy = MatDestroy_SubMatrix; 314 N->ops->solve = MatSolve_SubMatrix; 315 N->ops->mult = MatMult_SubMatrix; 316 N->ops->multadd = MatMultAdd_SubMatrix; 317 N->ops->multtranspose = MatMultTranspose_SubMatrix; 318 N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix; 319 N->ops->scale = MatScale_SubMatrix; 320 N->ops->diagonalscale = MatDiagonalScale_SubMatrix; 321 if (A->ops->shift) N->ops->shift = MatShift_SubMatrix; 322 323 ierr = MatSetBlockSizesFromMats(N,A,A);CHKERRQ(ierr); 324 ierr = PetscLayoutSetUp(N->rmap);CHKERRQ(ierr); 325 ierr = PetscLayoutSetUp(N->cmap);CHKERRQ(ierr); 326 327 ierr = MatCreateVecs(A,&Na->rwork,&Na->lwork);CHKERRQ(ierr); 328 ierr = VecCreate(PetscObjectComm((PetscObject)isrow),&left);CHKERRQ(ierr); 329 ierr = VecCreate(PetscObjectComm((PetscObject)iscol),&right);CHKERRQ(ierr); 330 ierr = VecSetSizes(left,m,PETSC_DETERMINE);CHKERRQ(ierr); 331 ierr = VecSetSizes(right,n,PETSC_DETERMINE);CHKERRQ(ierr); 332 ierr = VecSetUp(left);CHKERRQ(ierr); 333 ierr = VecSetUp(right);CHKERRQ(ierr); 334 ierr = VecScatterCreate(Na->lwork,isrow,left,NULL,&Na->lrestrict);CHKERRQ(ierr); 335 ierr = VecScatterCreate(right,NULL,Na->rwork,iscol,&Na->rprolong);CHKERRQ(ierr); 336 ierr = VecDestroy(&left);CHKERRQ(ierr); 337 ierr = VecDestroy(&right);CHKERRQ(ierr); 338 339 N->assembled = PETSC_TRUE; 340 341 ierr = MatSetUp(N);CHKERRQ(ierr); 342 343 *newmat = N; 344 PetscFunctionReturn(0); 345 } 346 347 348 /*@ 349 MatSubMatrixVirtualUpdate - Updates a submatrix 350 351 Collective on Mat 352 353 Input Parameters: 354 + N - submatrix to update 355 . A - full matrix in the submatrix 356 . isrow - rows in the update (same as the first time the submatrix was created) 357 - iscol - columns in the update (same as the first time the submatrix was created) 358 359 Level: developer 360 361 Notes: 362 Most will use MatCreateSubMatrix which provides a more efficient representation if it is available. 363 364 .seealso: MatCreateSubMatrixVirtual() 365 @*/ 366 PetscErrorCode MatSubMatrixVirtualUpdate(Mat N,Mat A,IS isrow,IS iscol) 367 { 368 PetscErrorCode ierr; 369 PetscBool flg; 370 Mat_SubVirtual *Na; 371 372 PetscFunctionBegin; 373 PetscValidHeaderSpecific(N,MAT_CLASSID,1); 374 PetscValidHeaderSpecific(A,MAT_CLASSID,2); 375 PetscValidHeaderSpecific(isrow,IS_CLASSID,3); 376 PetscValidHeaderSpecific(iscol,IS_CLASSID,4); 377 ierr = PetscObjectTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);CHKERRQ(ierr); 378 if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"Matrix has wrong type"); 379 380 Na = (Mat_SubVirtual*)N->data; 381 ierr = ISEqual(isrow,Na->isrow,&flg);CHKERRQ(ierr); 382 if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices"); 383 ierr = ISEqual(iscol,Na->iscol,&flg);CHKERRQ(ierr); 384 if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices"); 385 386 ierr = PetscObjectReference((PetscObject)A);CHKERRQ(ierr); 387 ierr = MatDestroy(&Na->A);CHKERRQ(ierr); 388 Na->A = A; 389 390 Na->scale = 1.0; 391 ierr = VecDestroy(&Na->left);CHKERRQ(ierr); 392 ierr = VecDestroy(&Na->right);CHKERRQ(ierr); 393 PetscFunctionReturn(0); 394 } 395