pyAMReX
Array4.H
Go to the documentation of this file.
1 /* Copyright 2021-2023 The AMReX Community
2  *
3  * Authors: Axel Huebl
4  * License: BSD-3-Clause-LBNL
5  */
6 #pragma once
7 
8 #include "pyAMReX.H"
9 
10 #include <AMReX_Array4.H>
11 #include <AMReX_BLassert.H>
12 #include <AMReX_GpuContainers.H>
13 #include <AMReX_IntVect.H>
14 
15 #include <complex>
16 #include <cstdint>
17 #include <iterator>
18 #include <sstream>
19 #include <type_traits>
20 #include <vector>
21 
22 
namespace
{
    // NOTE(review): this is an unnamed namespace inside a header, so every
    // translation unit that includes the file gets its own internal-linkage
    // copy of these helpers - confirm that this is intended.

    /** Type trait: the scalar type underlying an element type.
     *
     * Primary template: any type is its own scalar type.
     */
    template <typename T>
    struct get_value_type
    {
        using value_type = T;
    };

    /** Specialization: the scalar type of std::complex<T> is T. */
    template <typename T>
    struct get_value_type<std::complex<T>>
    {
        using value_type = T;
    };

    /** Shorthand for get_value_type<T>::value_type. */
    template <typename T>
    using get_value_type_t = typename get_value_type<T>::value_type;

    /** Check that an element type is free of cv-qualifiers.
     *
     * @tparam T element type, e.g., the T of an Array4<T>
     * @return true if neither T itself nor its underlying scalar type
     *         (see get_value_type_t) is changed by std::remove_cv_t,
     *         i.e., neither is const- or volatile-qualified
     */
    template <typename T>
    constexpr bool is_not_const ()
    {
        using scalar_t = get_value_type_t<T>;

        constexpr bool outer_unqualified =
            std::is_same_v<std::remove_cv_t<T>, T>;
        constexpr bool scalar_unqualified =
            std::is_same_v<std::remove_cv_t<scalar_t>, scalar_t>;

        return outer_unqualified && scalar_unqualified;
    }
}
51 
52 namespace pyAMReX
53 {
54  using namespace amrex;
55 
60  template<typename T>
61  py::dict
63  {
64  auto d = py::dict();
65  auto const len = length(a4);
66  // F->C index conversion here
67  // p[(i-begin.x)+(j-begin.y)*jstride+(k-begin.z)*kstride+n*nstride];
68  // Buffer dimensions: zero-size shall not skip dimension
69  auto shape = py::make_tuple(
70  py::ssize_t(a4.ncomp),
71  py::ssize_t(len.z <= 0 ? 1 : len.z),
72  py::ssize_t(len.y <= 0 ? 1 : len.y),
73  py::ssize_t(len.x <= 0 ? 1 : len.x) // fastest varying index
74  );
75  // buffer protocol strides are in bytes, AMReX strides are elements
76  auto const strides = py::make_tuple(
77  py::ssize_t(sizeof(T) * a4.nstride),
78  py::ssize_t(sizeof(T) * a4.kstride),
79  py::ssize_t(sizeof(T) * a4.jstride),
80  py::ssize_t(sizeof(T)) // fastest varying index
81  );
82  bool const read_only = false; // note: we could decide on is_not_const,
83  // but many libs, e.g. PyTorch, do not
84  // support read-only and will raise
85  // warnings, casting to read-write
86  d["data"] = py::make_tuple(std::intptr_t(a4.dataPtr()), read_only);
87  // note: if we want to keep the same global indexing with non-zero
88  // box small_end as in AMReX, then we can explore playing with
89  // this offset as well
90  //d["offset"] = 0; // default
91  //d["mask"] = py::none(); // default
92 
93  d["shape"] = shape;
94  // we could also set this after checking the strides are C-style contiguous:
95  //if (is_contiguous<T>(shape, strides))
96  // d["strides"] = py::none(); // C-style contiguous
97  //else
98  d["strides"] = strides;
99 
100  // type description
101  // for more complicated types, e.g., tuples/structs
102  //d["descr"] = ...;
103  // we currently only need this
104  using T_no_cv = std::remove_cv_t<T>;
105  d["typestr"] = py::format_descriptor<T_no_cv>::format();
106 
107  d["version"] = 3;
108  return d;
109  }
110 
111  template< typename T >
112  void make_Array4(py::module &m, std::string typestr)
113  {
114  using namespace amrex;
115 
116  using T_no_cv = std::remove_cv_t<T>;
117 
118  // dispatch simpler via: py::format_descriptor<T>::format() naming
119  // but note the _const suffix that might be needed
120  auto const array_name = std::string("Array4_").append(typestr);
121  py::class_< Array4<T> > py_array4(m, array_name.c_str());
122  py_array4
123  .def("__repr__",
124  [typestr](Array4<T> const & a4) {
125  std::stringstream s;
126  s << a4.size();
127  return "<amrex.Array4 of type '" + typestr +
128  "' and size '" + s.str() + "'>";
129  }
130  )
131  #if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK)
132  .def("index_assert", &Array4<T>::index_assert)
133  #endif
134 
135  .def_property_readonly("size", &Array4<T>::size)
136  .def_property_readonly("nComp", &Array4<T>::nComp)
137  .def_property_readonly("num_comp", &Array4<T>::nComp)
138 
139  .def(py::init< >())
140  .def(py::init< Array4<T> const & >())
141  .def(py::init< Array4<T> const &, int >())
142  .def(py::init< Array4<T> const &, int, int >())
143  //.def(py::init< T*, Dim3 const &, Dim3 const &, int >())
144 
145  /* init from a numpy or other buffer protocol array: non-owning view
146  */
147  .def(py::init([](py::array_t<T> & arr) {
148  py::buffer_info buf = arr.request();
149 
150  AMREX_ALWAYS_ASSERT_WITH_MESSAGE(buf.ndim == 3,
151  "We can only create amrex::Array4 views into 3D Python arrays at the moment.");
152  // TODO:
153  // In 2D, Array4 still needs to be accessed with (i,j,k) or (i,j,k,n), with k = 0.
154  // Likewise in 1D.
155  // We could also add support for 4D numpy arrays, treating the slowest
156  // varying index as component "n".
157 
158  if (buf.format != py::format_descriptor<T_no_cv>::format())
159  throw std::runtime_error("Incompatible format: expected '" +
160  py::format_descriptor<T_no_cv>::format() +
161  "' and received '" + buf.format + "'!");
162 
163  auto a4 = std::make_unique< Array4<T> >();
164  a4.get()->p = static_cast<T*>(buf.ptr);
165  a4.get()->begin = Dim3{0, 0, 0};
166  // C->F index conversion here
167  // p[(i-begin.x)+(j-begin.y)*jstride+(k-begin.z)*kstride+n*nstride];
168  a4.get()->end.x = (int)buf.shape.at(2); // fastest varying index
169  a4.get()->end.y = (int)buf.shape.at(1);
170  a4.get()->end.z = (int)buf.shape.at(0);
171  a4.get()->ncomp = 1;
172  // buffer protocol strides are in bytes, AMReX strides are elements
173  a4.get()->jstride = (int)buf.strides.at(1) / sizeof(T); // fastest varying index
174  a4.get()->kstride = (int)buf.strides.at(0) / sizeof(T);
175  // 3D == no component: stride here should not matter
176  a4.get()->nstride = a4.get()->kstride * (int)buf.shape.at(0);
177 
178  // todo: we could check and store here if the array buffer we got is read-only
179 
180  return a4;
181  }))
182 
183  /* init from __cuda_array_interface__: non-owning view
184  * TODO
185  */
186 
187 
188  // CPU: __array_interface__ v3
189  // https://numpy.org/doc/stable/reference/arrays.interface.html
190  .def_property_readonly("__array_interface__", [](Array4<T> const & a4) {
191  return pyAMReX::array_interface(a4);
192  })
193 
194  // CPU: __array_function__ interface (TODO)
195  //
196  // NEP 18 — A dispatch mechanism for NumPy's high level array functions.
197  // https://numpy.org/neps/nep-0018-array-function-protocol.html
198  // This enables code using NumPy to be directly operated on Array4 arrays.
199  // __array_function__ feature requires NumPy 1.16 or later.
200 
201 
202  // Nvidia GPUs: __cuda_array_interface__ v3
203  // https://numba.readthedocs.io/en/latest/cuda/cuda_array_interface.html
204  .def_property_readonly("__cuda_array_interface__", [](Array4<T> const & a4) {
205  auto d = pyAMReX::array_interface(a4);
206 
207  // data:
208  // Because the user of the interface may or may not be in the same context, the most common case is to use cuPointerGetAttribute with CU_POINTER_ATTRIBUTE_DEVICE_POINTER in the CUDA driver API (or the equivalent CUDA Runtime API) to retrieve a device pointer that is usable in the currently active context.
209  // TODO For zero-size arrays, use 0 here.
210 
211  // None or integer
212  // An optional stream upon which synchronization must take place at the point of consumption, either by synchronizing on the stream or enqueuing operations on the data on the given stream. Integer values in this entry are as follows:
213  // 0: This is disallowed as it would be ambiguous between None and the default stream, and also between the legacy and per-thread default streams. Any use case where 0 might be given should either use None, 1, or 2 instead for clarity.
214  // 1: The legacy default stream.
215  // 2: The per-thread default stream.
216  // Any other integer: a cudaStream_t represented as a Python integer.
217  // When None, no synchronization is required.
218  d["stream"] = py::none();
219 
220  d["version"] = 3;
221  return d;
222  })
223 
224 
225  // TODO: __dlpack__ __dlpack_device__
226  // DLPack protocol (CPU, NVIDIA GPU, AMD GPU, Intel GPU, etc.)
227  // https://dmlc.github.io/dlpack/latest/
228  // https://data-apis.org/array-api/latest/design_topics/data_interchange.html
229  // https://github.com/data-apis/consortium-feedback/issues/1
230  // https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
231  // https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol
232 
233  .def("to_host", [](Array4<T> const & a4) {
234  // py::tuple to std::vector
235  auto const a4i = pyAMReX::array_interface(a4);
236  auto const shape = py::cast<std::vector<py::ssize_t>>(a4i["shape"]);
237  auto const strides_bytes = py::cast<std::vector<py::ssize_t>>(a4i["strides"]);
238 
239  // allocate host memory copy
240  auto h_data = py::array_t<T_no_cv>(shape, strides_bytes);
241 
242  // sync copy: host data is unpinned
243  Gpu::copy(Gpu::deviceToHost,
244  a4.dataPtr(), a4.dataPtr() + a4.size(),
245  h_data.mutable_data()
246  );
247  return h_data;
248  }, py::return_value_policy::move)
249 
250  .def("contains", &Array4<T>::contains)
251  //.def("__contains__", &Array4<T>::contains)
252 
253  // getter
254  .def("__getitem__", [](Array4<T> & a4, IntVect const & v){ return a4(v); })
255  .def("__getitem__", [](Array4<T> & a4, std::array<int, 4> const key){
256  return a4(key[0], key[1], key[2], key[3]);
257  })
258  .def("__getitem__", [](Array4<T> & a4, std::array<int, 3> const key){
259  return a4(key[0], key[1], key[2]);
260  })
261  ;
262 
263  // setter
264  if constexpr (is_not_const<T>())
265  {
266  py_array4
267  .def("__setitem__", [](Array4<T> & a4, IntVect const & v, T const value){ a4(v) = value; })
268  .def("__setitem__", [](Array4<T> & a4, std::array<int, 4> const key, T const value){
269  a4(key[0], key[1], key[2], key[3]) = value;
270  })
271  .def("__setitem__", [](Array4<T> & a4, std::array<int, 3> const key, T const value){
272  a4(key[0], key[1], key[2]) = value;
273  })
274  ;
275  }
276 
277  // free standing C++ functions:
278  m.def("lbound", &lbound< T >);
279  m.def("ubound", &ubound< T >);
280  m.def("length", &length< T >);
281  //m.def("makePolymorphic", &makePolymorphic< T >);
282  }
283 }
#define AMREX_ALWAYS_ASSERT_WITH_MESSAGE(EX, MSG)
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE Dim3 length(Array4< T > const &a) noexcept
const int[]
Definition: Array4.H:53
void make_Array4(py::module &m, std::string typestr)
Definition: Array4.H:112
py::dict array_interface(Array4< T > const &a4)
Definition: Array4.H:62
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE std::size_t size() const noexcept
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE T * dataPtr() const noexcept
T *AMREX_RESTRICT p