xref: /libCEED/rust/libceed-sys/c-src/backends/hip-gen/ceed-hip-gen-operator.c (revision 9123fb08d52f01bdd0d1f3a790ba84e4ab900e9f)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-source/hip/hip-types.h>
11 #include <stddef.h>
12 #include <hip/hiprtc.h>
13 
14 #include "../hip/ceed-hip-common.h"
15 #include "../hip/ceed-hip-compile.h"
16 #include "ceed-hip-gen-operator-build.h"
17 #include "ceed-hip-gen.h"
18 
19 //------------------------------------------------------------------------------
20 // Destroy operator
21 //------------------------------------------------------------------------------
22 static int CeedOperatorDestroy_Hip_gen(CeedOperator op) {
23   Ceed                  ceed;
24   CeedOperator_Hip_gen *impl;
25 
26   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
27   CeedCallBackend(CeedOperatorGetData(op, &impl));
28   if (impl->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)impl->points.num_per_elem));
29   CeedCallBackend(CeedFree(&impl));
30   CeedCallBackend(CeedDestroy(&ceed));
31   return CEED_ERROR_SUCCESS;
32 }
33 
34 //------------------------------------------------------------------------------
35 // Apply and add to output
36 //------------------------------------------------------------------------------
37 static int CeedOperatorApplyAdd_Hip_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
38   bool                   is_at_points, is_tensor;
39   Ceed                   ceed;
40   CeedInt                num_elem, num_input_fields, num_output_fields;
41   CeedEvalMode           eval_mode;
42   CeedVector             output_vecs[CEED_FIELD_MAX] = {NULL};
43   CeedQFunctionField    *qf_input_fields, *qf_output_fields;
44   CeedQFunction_Hip_gen *qf_data;
45   CeedQFunction          qf;
46   CeedOperatorField     *op_input_fields, *op_output_fields;
47   CeedOperator_Hip_gen  *data;
48 
49   // Check for shared bases
50   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
51   {
52     bool has_shared_bases = true, is_all_tensor = true, is_all_nontensor = true;
53 
54     for (CeedInt i = 0; i < num_input_fields; i++) {
55       CeedBasis basis;
56 
57       CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
58       if (basis != CEED_BASIS_NONE) {
59         bool        is_tensor = true;
60         const char *resource;
61         char       *resource_root;
62         Ceed        basis_ceed;
63 
64         CeedCallBackend(CeedBasisIsTensor(basis, &is_tensor));
65         is_all_tensor &= is_tensor;
66         is_all_nontensor &= !is_tensor;
67         CeedCallBackend(CeedBasisGetCeed(basis, &basis_ceed));
68         CeedCallBackend(CeedGetResource(basis_ceed, &resource));
69         CeedCallBackend(CeedGetResourceRoot(basis_ceed, resource, ":", &resource_root));
70         has_shared_bases &= !strcmp(resource_root, "/gpu/hip/shared");
71         CeedCallBackend(CeedFree(&resource_root));
72         CeedCallBackend(CeedDestroy(&basis_ceed));
73       }
74       CeedCallBackend(CeedBasisDestroy(&basis));
75     }
76 
77     for (CeedInt i = 0; i < num_output_fields; i++) {
78       CeedBasis basis;
79 
80       CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
81       if (basis != CEED_BASIS_NONE) {
82         bool        is_tensor = true;
83         const char *resource;
84         char       *resource_root;
85         Ceed        basis_ceed;
86 
87         CeedCallBackend(CeedBasisIsTensor(basis, &is_tensor));
88         is_all_tensor &= is_tensor;
89         is_all_nontensor &= !is_tensor;
90 
91         CeedCallBackend(CeedBasisGetCeed(basis, &basis_ceed));
92         CeedCallBackend(CeedGetResource(basis_ceed, &resource));
93         CeedCallBackend(CeedGetResourceRoot(basis_ceed, resource, ":", &resource_root));
94         has_shared_bases &= !strcmp(resource_root, "/gpu/hip/shared");
95         CeedCallBackend(CeedFree(&resource_root));
96         CeedCallBackend(CeedDestroy(&basis_ceed));
97       }
98       CeedCallBackend(CeedBasisDestroy(&basis));
99     }
100     // -- Fallback to ref if not all bases are shared
101     if (!has_shared_bases || (!is_all_tensor && !is_all_nontensor)) {
102       CeedOperator op_fallback;
103 
104       CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to /gpu/hip/ref CeedOperator due to unsupported bases");
105       CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
106       CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
107       return CEED_ERROR_SUCCESS;
108     }
109   }
110 
111   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
112   CeedCallBackend(CeedOperatorGetData(op, &data));
113   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
114   CeedCallBackend(CeedQFunctionGetData(qf, &qf_data));
115   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
116   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
117 
118   // Creation of the operator
119   CeedCallBackend(CeedOperatorBuildKernel_Hip_gen(op));
120 
121   // Input vectors
122   for (CeedInt i = 0; i < num_input_fields; i++) {
123     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
124     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
125       data->fields.inputs[i] = NULL;
126     } else {
127       CeedVector vec;
128 
129       // Get input vector
130       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
131       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
132       CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->fields.inputs[i]));
133     }
134   }
135 
136   // Output vectors
137   for (CeedInt i = 0; i < num_output_fields; i++) {
138     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
139     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
140       data->fields.outputs[i] = NULL;
141     } else {
142       CeedVector vec;
143 
144       // Get output vector
145       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
146       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
147       output_vecs[i] = vec;
148       // Check for multiple output modes
149       CeedInt index = -1;
150       for (CeedInt j = 0; j < i; j++) {
151         if (vec == output_vecs[j]) {
152           index = j;
153           break;
154         }
155       }
156       if (index == -1) {
157         CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &data->fields.outputs[i]));
158       } else {
159         data->fields.outputs[i] = data->fields.outputs[index];
160       }
161     }
162   }
163 
164   // Point coordinates, if needed
165   CeedCallBackend(CeedOperatorIsAtPoints(op, &is_at_points));
166   if (is_at_points) {
167     // Coords
168     CeedVector vec;
169 
170     CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
171     CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &data->points.coords));
172     CeedCallBackend(CeedVectorDestroy(&vec));
173 
174     // Points per elem
175     if (num_elem != data->points.num_elem) {
176       CeedInt            *points_per_elem;
177       const CeedInt       num_bytes   = num_elem * sizeof(CeedInt);
178       CeedElemRestriction rstr_points = NULL;
179 
180       data->points.num_elem = num_elem;
181       CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, NULL));
182       CeedCallBackend(CeedCalloc(num_elem, &points_per_elem));
183       for (CeedInt e = 0; e < num_elem; e++) {
184         CeedInt num_points_elem;
185 
186         CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points_elem));
187         points_per_elem[e] = num_points_elem;
188       }
189       if (data->points.num_per_elem) CeedCallHip(ceed, hipFree((void **)data->points.num_per_elem));
190       CeedCallHip(ceed, hipMalloc((void **)&data->points.num_per_elem, num_bytes));
191       CeedCallHip(ceed, hipMemcpy((void *)data->points.num_per_elem, points_per_elem, num_bytes, hipMemcpyHostToDevice));
192       CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
193       CeedCallBackend(CeedFree(&points_per_elem));
194     }
195   }
196 
197   // Get context data
198   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_data->d_c));
199 
200   // Apply operator
201   void         *opargs[]  = {(void *)&num_elem, &qf_data->d_c, &data->indices, &data->fields, &data->B, &data->G, &data->W, &data->points};
202   const CeedInt dim       = data->dim;
203   const CeedInt Q_1d      = data->Q_1d;
204   const CeedInt P_1d      = data->max_P_1d;
205   const CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
206   CeedInt       block_sizes[3];
207 
208   CeedCallBackend(CeedOperatorHasTensorBases(op, &is_tensor));
209   CeedCallBackend(BlockGridCalculate_Hip_gen(is_tensor ? dim : 1, num_elem, P_1d, Q_1d, block_sizes));
210   if (dim == 1 || !is_tensor) {
211     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
212     CeedInt sharedMem = block_sizes[2] * thread_1d * sizeof(CeedScalar);
213 
214     CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
215   } else if (dim == 2) {
216     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
217     CeedInt sharedMem = block_sizes[2] * thread_1d * thread_1d * sizeof(CeedScalar);
218 
219     CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
220   } else if (dim == 3) {
221     CeedInt grid      = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
222     CeedInt sharedMem = block_sizes[2] * thread_1d * thread_1d * sizeof(CeedScalar);
223 
224     CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
225   }
226 
227   // Restore input arrays
228   for (CeedInt i = 0; i < num_input_fields; i++) {
229     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
230     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
231     } else {
232       CeedVector vec;
233 
234       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
235       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
236       CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->fields.inputs[i]));
237     }
238   }
239 
240   // Restore output arrays
241   for (CeedInt i = 0; i < num_output_fields; i++) {
242     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
243     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
244     } else {
245       CeedVector vec;
246 
247       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
248       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
249       // Check for multiple output modes
250       CeedInt index = -1;
251 
252       for (CeedInt j = 0; j < i; j++) {
253         if (vec == output_vecs[j]) {
254           index = j;
255           break;
256         }
257       }
258       if (index == -1) {
259         CeedCallBackend(CeedVectorRestoreArray(vec, &data->fields.outputs[i]));
260       }
261     }
262   }
263 
264   // Restore point coordinates, if needed
265   if (is_at_points) {
266     CeedVector vec;
267 
268     CeedCallBackend(CeedOperatorAtPointsGetPoints(op, NULL, &vec));
269     CeedCallBackend(CeedVectorRestoreArrayRead(vec, &data->points.coords));
270     CeedCallBackend(CeedVectorDestroy(&vec));
271   }
272 
273   // Restore context data
274   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_data->d_c));
275   CeedCallBackend(CeedDestroy(&ceed));
276   CeedCallBackend(CeedQFunctionDestroy(&qf));
277   return CEED_ERROR_SUCCESS;
278 }
279 
280 //------------------------------------------------------------------------------
281 // Create operator
282 //------------------------------------------------------------------------------
283 int CeedOperatorCreate_Hip_gen(CeedOperator op) {
284   Ceed                  ceed;
285   CeedOperator_Hip_gen *impl;
286 
287   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
288   CeedCallBackend(CeedCalloc(1, &impl));
289   CeedCallBackend(CeedOperatorSetData(op, impl));
290   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip_gen));
291   CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip_gen));
292   CeedCallBackend(CeedDestroy(&ceed));
293   return CEED_ERROR_SUCCESS;
294 }
295 
296 //------------------------------------------------------------------------------
297