xref: /libCEED/rust/libceed-sys/c-src/backends/hip-gen/ceed-hip-gen-operator.c (revision 9e201c85545dd39529c090846df629a32c15659b)
13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
37d8d0e25Snbeams //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
57d8d0e25Snbeams //
63d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
77d8d0e25Snbeams 
8ec3da8bcSJed Brown #include <ceed/ceed.h>
9ec3da8bcSJed Brown #include <ceed/backend.h>
103d576824SJeremy L Thompson #include <stddef.h>
117d8d0e25Snbeams #include "ceed-hip-gen.h"
127d8d0e25Snbeams #include "ceed-hip-gen-operator-build.h"
137d8d0e25Snbeams #include "../hip/ceed-hip-compile.h"
147d8d0e25Snbeams 
157d8d0e25Snbeams //------------------------------------------------------------------------------
167d8d0e25Snbeams // Destroy operator
177d8d0e25Snbeams //------------------------------------------------------------------------------
187d8d0e25Snbeams static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
197d8d0e25Snbeams   int ierr;
207d8d0e25Snbeams   CeedOperator_Hip_gen *impl;
21e15f9bd0SJeremy L Thompson   ierr = CeedOperatorGetData(op, &impl); CeedChkBackend(ierr);
22e15f9bd0SJeremy L Thompson   ierr = CeedFree(&impl); CeedChkBackend(ierr);
23e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
247d8d0e25Snbeams }
257d8d0e25Snbeams 
267d8d0e25Snbeams //------------------------------------------------------------------------------
277d8d0e25Snbeams // Apply and add to output
287d8d0e25Snbeams //------------------------------------------------------------------------------
29*9e201c85SYohann static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec,
30*9e201c85SYohann                                         CeedVector output_vec, CeedRequest *request) {
317d8d0e25Snbeams   int ierr;
327d8d0e25Snbeams   Ceed ceed;
33e15f9bd0SJeremy L Thompson   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
347d8d0e25Snbeams   CeedOperator_Hip_gen *data;
35e15f9bd0SJeremy L Thompson   ierr = CeedOperatorGetData(op, &data); CeedChkBackend(ierr);
367d8d0e25Snbeams   CeedQFunction qf;
377d8d0e25Snbeams   CeedQFunction_Hip_gen *qf_data;
38e15f9bd0SJeremy L Thompson   ierr = CeedOperatorGetQFunction(op, &qf); CeedChkBackend(ierr);
39e15f9bd0SJeremy L Thompson   ierr = CeedQFunctionGetData(qf, &qf_data); CeedChkBackend(ierr);
40*9e201c85SYohann   CeedInt num_elem, num_input_fields, num_output_fields;
41*9e201c85SYohann   ierr = CeedOperatorGetNumElements(op, &num_elem); CeedChkBackend(ierr);
42*9e201c85SYohann   CeedOperatorField *op_input_fields, *op_output_fields;
43*9e201c85SYohann   ierr = CeedOperatorGetFields(op, &num_input_fields, &op_input_fields,
44*9e201c85SYohann                                &num_output_fields, &op_output_fields);
45e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
46*9e201c85SYohann   CeedQFunctionField *qf_input_fields, *qf_output_fields;
47*9e201c85SYohann   ierr = CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL,
48*9e201c85SYohann                                 &qf_output_fields);
49e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
50*9e201c85SYohann   CeedEvalMode eval_mode;
51*9e201c85SYohann   CeedVector vec, output_vecs[CEED_FIELD_MAX] = {};
527d8d0e25Snbeams 
537d8d0e25Snbeams   //Creation of the operator
54e15f9bd0SJeremy L Thompson   ierr = CeedHipGenOperatorBuild(op); CeedChkBackend(ierr);
557d8d0e25Snbeams 
567d8d0e25Snbeams   // Input vectors
57*9e201c85SYohann   for (CeedInt i = 0; i < num_input_fields; i++) {
58*9e201c85SYohann     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
59e15f9bd0SJeremy L Thompson     CeedChkBackend(ierr);
60*9e201c85SYohann     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
61*9e201c85SYohann       data->fields.inputs[i] = NULL;
627d8d0e25Snbeams     } else {
637d8d0e25Snbeams       // Get input vector
64*9e201c85SYohann       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
65*9e201c85SYohann       CeedChkBackend(ierr);
66*9e201c85SYohann       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
67*9e201c85SYohann       ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]);
68e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
697d8d0e25Snbeams     }
707d8d0e25Snbeams   }
717d8d0e25Snbeams 
727d8d0e25Snbeams   // Output vectors
73*9e201c85SYohann   for (CeedInt i = 0; i < num_output_fields; i++) {
74*9e201c85SYohann     ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode);
75e15f9bd0SJeremy L Thompson     CeedChkBackend(ierr);
76*9e201c85SYohann     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
77*9e201c85SYohann       data->fields.outputs[i] = NULL;
787d8d0e25Snbeams     } else {
797d8d0e25Snbeams       // Get output vector
80*9e201c85SYohann       ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec);
81e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
82*9e201c85SYohann       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
83*9e201c85SYohann       output_vecs[i] = vec;
847d8d0e25Snbeams       // Check for multiple output modes
857d8d0e25Snbeams       CeedInt index = -1;
867d8d0e25Snbeams       for (CeedInt j = 0; j < i; j++) {
87*9e201c85SYohann         if (vec == output_vecs[j]) {
887d8d0e25Snbeams           index = j;
897d8d0e25Snbeams           break;
907d8d0e25Snbeams         }
917d8d0e25Snbeams       }
927d8d0e25Snbeams       if (index == -1) {
93*9e201c85SYohann         ierr = CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i]);
94e15f9bd0SJeremy L Thompson         CeedChkBackend(ierr);
957d8d0e25Snbeams       } else {
96*9e201c85SYohann         data->fields.outputs[i] = data->fields.outputs[index];
977d8d0e25Snbeams       }
987d8d0e25Snbeams     }
997d8d0e25Snbeams   }
1007d8d0e25Snbeams 
1017d8d0e25Snbeams   // Get context data
102441428dfSJeremy L Thompson   ierr = CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c);
103e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
1047d8d0e25Snbeams 
1057d8d0e25Snbeams   // Apply operator
106*9e201c85SYohann   void *opargs[] = {(void *) &num_elem, &qf_data->d_c, &data->indices,
1077d8d0e25Snbeams                     &data->fields, &data->B, &data->G, &data->W
1087d8d0e25Snbeams                    };
1097d8d0e25Snbeams   const CeedInt dim = data->dim;
110*9e201c85SYohann   const CeedInt Q_1d = data->Q_1d;
111*9e201c85SYohann   const CeedInt P_1d = data->max_P_1d;
112*9e201c85SYohann   const CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
113b3e1519bSnbeams   CeedInt block_sizes[3];
114*9e201c85SYohann   ierr = BlockGridCalculate_Hip_gen(dim, num_elem, P_1d, Q_1d, block_sizes);
115b3e1519bSnbeams   CeedChkBackend(ierr);
1167d8d0e25Snbeams   if (dim==1) {
117*9e201c85SYohann     CeedInt grid = num_elem/block_sizes[2] + ( (
118*9e201c85SYohann                      num_elem/block_sizes[2]*block_sizes[2]<num_elem)
1197d8d0e25Snbeams                    ? 1 : 0 );
120*9e201c85SYohann     CeedInt sharedMem = block_sizes[2]*thread_1d*sizeof(CeedScalar);
121b3e1519bSnbeams     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0],
122b3e1519bSnbeams                                      block_sizes[1],
123b3e1519bSnbeams                                      block_sizes[2], sharedMem, opargs);
1247d8d0e25Snbeams   } else if (dim==2) {
125*9e201c85SYohann     CeedInt grid = num_elem/block_sizes[2] + ( (
126*9e201c85SYohann                      num_elem/block_sizes[2]*block_sizes[2]<num_elem)
1277d8d0e25Snbeams                    ? 1 : 0 );
128*9e201c85SYohann     CeedInt sharedMem = block_sizes[2]*thread_1d*thread_1d*sizeof(CeedScalar);
129b3e1519bSnbeams     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0],
130b3e1519bSnbeams                                      block_sizes[1],
131b3e1519bSnbeams                                      block_sizes[2], sharedMem, opargs);
1327d8d0e25Snbeams   } else if (dim==3) {
133*9e201c85SYohann     CeedInt grid = num_elem/block_sizes[2] + ( (
134*9e201c85SYohann                      num_elem/block_sizes[2]*block_sizes[2]<num_elem)
1357d8d0e25Snbeams                    ? 1 : 0 );
136*9e201c85SYohann     CeedInt sharedMem = block_sizes[2]*thread_1d*thread_1d*sizeof(CeedScalar);
137b3e1519bSnbeams     ierr = CeedRunKernelDimSharedHip(ceed, data->op, grid, block_sizes[0],
138b3e1519bSnbeams                                      block_sizes[1],
139b3e1519bSnbeams                                      block_sizes[2], sharedMem, opargs);
1407d8d0e25Snbeams   }
141e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
1427d8d0e25Snbeams 
1437d8d0e25Snbeams   // Restore input arrays
144*9e201c85SYohann   for (CeedInt i = 0; i < num_input_fields; i++) {
145*9e201c85SYohann     ierr = CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode);
146e15f9bd0SJeremy L Thompson     CeedChkBackend(ierr);
147*9e201c85SYohann     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
1487d8d0e25Snbeams     } else {
149*9e201c85SYohann       ierr = CeedOperatorFieldGetVector(op_input_fields[i], &vec);
150*9e201c85SYohann       CeedChkBackend(ierr);
151*9e201c85SYohann       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
152*9e201c85SYohann       ierr = CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]);
153e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
1547d8d0e25Snbeams     }
1557d8d0e25Snbeams   }
1567d8d0e25Snbeams 
1577d8d0e25Snbeams   // Restore output arrays
158*9e201c85SYohann   for (CeedInt i = 0; i < num_output_fields; i++) {
159*9e201c85SYohann     ierr = CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode);
160e15f9bd0SJeremy L Thompson     CeedChkBackend(ierr);
161*9e201c85SYohann     if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
1627d8d0e25Snbeams     } else {
163*9e201c85SYohann       ierr = CeedOperatorFieldGetVector(op_output_fields[i], &vec);
164e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
165*9e201c85SYohann       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
1667d8d0e25Snbeams       // Check for multiple output modes
1677d8d0e25Snbeams       CeedInt index = -1;
1687d8d0e25Snbeams       for (CeedInt j = 0; j < i; j++) {
169*9e201c85SYohann         if (vec == output_vecs[j]) {
1707d8d0e25Snbeams           index = j;
1717d8d0e25Snbeams           break;
1727d8d0e25Snbeams         }
1737d8d0e25Snbeams       }
1747d8d0e25Snbeams       if (index == -1) {
175*9e201c85SYohann         ierr = CeedVectorRestoreArray(vec, &data->fields.outputs[i]);
176e15f9bd0SJeremy L Thompson         CeedChkBackend(ierr);
1777d8d0e25Snbeams       }
1787d8d0e25Snbeams     }
1797d8d0e25Snbeams   }
1807d8d0e25Snbeams 
1817d8d0e25Snbeams   // Restore context data
182441428dfSJeremy L Thompson   ierr = CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c);
183e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
184441428dfSJeremy L Thompson 
185e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1867d8d0e25Snbeams }
1877d8d0e25Snbeams 
1887d8d0e25Snbeams //------------------------------------------------------------------------------
1897d8d0e25Snbeams // Create operator
1907d8d0e25Snbeams //------------------------------------------------------------------------------
1917d8d0e25Snbeams int CeedOperatorCreate_Hip_gen(CeedOperator op) {
1927d8d0e25Snbeams   int ierr;
1937d8d0e25Snbeams   Ceed ceed;
194e15f9bd0SJeremy L Thompson   ierr = CeedOperatorGetCeed(op, &ceed); CeedChkBackend(ierr);
1957d8d0e25Snbeams   CeedOperator_Hip_gen *impl;
1967d8d0e25Snbeams 
197e15f9bd0SJeremy L Thompson   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
198e15f9bd0SJeremy L Thompson   ierr = CeedOperatorSetData(op, impl); CeedChkBackend(ierr);
1997d8d0e25Snbeams 
2007d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd",
201e15f9bd0SJeremy L Thompson                                 CeedOperatorApplyAdd_Hip_gen); CeedChkBackend(ierr);
2027d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Operator", op, "Destroy",
203e15f9bd0SJeremy L Thompson                                 CeedOperatorDestroy_Hip_gen); CeedChkBackend(ierr);
204e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2057d8d0e25Snbeams }
2067d8d0e25Snbeams //------------------------------------------------------------------------------
207