Kokkos Core Kernels Package  Version of the Day
Kokkos_Half.hpp
1 //@HEADER
2 // ************************************************************************
3 //
4 // Kokkos v. 4.0
5 // Copyright (2022) National Technology & Engineering
6 // Solutions of Sandia, LLC (NTESS).
7 //
8 // Under the terms of Contract DE-NA0003525 with NTESS,
9 // the U.S. Government retains certain rights in this software.
10 //
11 // Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12 // See https://kokkos.org/LICENSE for license information.
13 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14 //
15 //@HEADER
16 
17 #ifndef KOKKOS_HALF_HPP_
18 #define KOKKOS_HALF_HPP_
19 #ifndef KOKKOS_IMPL_PUBLIC_INCLUDE
20 #define KOKKOS_IMPL_PUBLIC_INCLUDE
21 #define KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_HALF
22 #endif
23 
#include <cstdint>  // std::uint16_t / std::uint32_t / std::uint64_t
#include <type_traits>
#include <Kokkos_Macros.hpp>
#include <iosfwd>  // istream & ostream for extraction and insertion ops
#include <string>
28 
29 #ifdef KOKKOS_IMPL_HALF_TYPE_DEFINED
30 
// KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH: A macro to select which
// floating_point_wrapper operator paths should be used. For CUDA, let the
// compiler conditionally select when device ops are used. For SYCL, we have a
// full half type on both host and device.
35 #if defined(__CUDA_ARCH__) || defined(KOKKOS_ENABLE_SYCL)
36 #define KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
37 #endif
38 
39 /************************* BEGIN forward declarations *************************/
40 namespace Kokkos {
41 namespace Experimental {
42 namespace Impl {
43 template <class FloatType>
44 class floating_point_wrapper;
45 }
46 
47 // Declare half_t (binary16)
48 using half_t = Kokkos::Experimental::Impl::floating_point_wrapper<
49  Kokkos::Impl::half_impl_t ::type>;
50 KOKKOS_INLINE_FUNCTION
51 half_t cast_to_half(float val);
52 KOKKOS_INLINE_FUNCTION
53 half_t cast_to_half(bool val);
54 KOKKOS_INLINE_FUNCTION
55 half_t cast_to_half(double val);
56 KOKKOS_INLINE_FUNCTION
57 half_t cast_to_half(short val);
58 KOKKOS_INLINE_FUNCTION
59 half_t cast_to_half(int val);
60 KOKKOS_INLINE_FUNCTION
61 half_t cast_to_half(long val);
62 KOKKOS_INLINE_FUNCTION
63 half_t cast_to_half(long long val);
64 KOKKOS_INLINE_FUNCTION
65 half_t cast_to_half(unsigned short val);
66 KOKKOS_INLINE_FUNCTION
67 half_t cast_to_half(unsigned int val);
68 KOKKOS_INLINE_FUNCTION
69 half_t cast_to_half(unsigned long val);
70 KOKKOS_INLINE_FUNCTION
71 half_t cast_to_half(unsigned long long val);
72 KOKKOS_INLINE_FUNCTION
73 half_t cast_to_half(half_t);
74 
75 template <class T>
76 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, float>::value, T>
77  cast_from_half(half_t);
78 template <class T>
79 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, bool>::value, T>
80  cast_from_half(half_t);
81 template <class T>
82 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, double>::value, T>
83  cast_from_half(half_t);
84 template <class T>
85 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, short>::value, T>
86  cast_from_half(half_t);
87 template <class T>
88 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, int>::value, T>
89  cast_from_half(half_t);
90 template <class T>
91 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, long>::value, T>
92  cast_from_half(half_t);
93 template <class T>
94 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, long long>::value, T>
95  cast_from_half(half_t);
96 template <class T>
97 KOKKOS_INLINE_FUNCTION
98  std::enable_if_t<std::is_same<T, unsigned short>::value, T>
99  cast_from_half(half_t);
100 template <class T>
101 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, unsigned int>::value, T>
102  cast_from_half(half_t);
103 template <class T>
104 KOKKOS_INLINE_FUNCTION
105  std::enable_if_t<std::is_same<T, unsigned long>::value, T>
106  cast_from_half(half_t);
107 template <class T>
108 KOKKOS_INLINE_FUNCTION
109  std::enable_if_t<std::is_same<T, unsigned long long>::value, T>
110  cast_from_half(half_t);
111 
// bhalf_t and its conversion functions; only declared when
// KOKKOS_IMPL_BHALF_TYPE_DEFINED is set.
#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED
// bhalf_t: floating_point_wrapper instantiated on
// Kokkos::Impl::bhalf_impl_t::type.
using bhalf_t = Kokkos::Experimental::Impl::floating_point_wrapper<
    Kokkos::Impl ::bhalf_impl_t ::type>;

// cast_to_bhalf: one overload per supported source type (declarations only).

// -- floating-point and bool sources
KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(float val);
KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(double val);
KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(bool val);

// -- signed integer sources
KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(short val);
KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(int val);
KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(long val);
KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(long long val);

// -- unsigned integer sources
KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(unsigned short val);
KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(unsigned int val);
KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(unsigned long val);
KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(unsigned long long val);

// -- bhalf_t source
KOKKOS_INLINE_FUNCTION bhalf_t cast_to_bhalf(bhalf_t val);

// cast_from_bhalf<T>: each declaration is enabled only when T is exactly the
// type named in its std::is_same constraint (declarations only).

// -- floating-point and bool targets
template <class T>
KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, float>::value, T>
cast_from_bhalf(bhalf_t val);

template <class T>
KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, double>::value, T>
cast_from_bhalf(bhalf_t val);

template <class T>
KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, bool>::value, T>
cast_from_bhalf(bhalf_t val);

// -- signed integer targets
template <class T>
KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, short>::value, T>
cast_from_bhalf(bhalf_t val);

template <class T>
KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, int>::value, T>
cast_from_bhalf(bhalf_t val);

template <class T>
KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, long>::value, T>
cast_from_bhalf(bhalf_t val);

template <class T>
KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, long long>::value, T>
cast_from_bhalf(bhalf_t val);

// -- unsigned integer targets
template <class T>
KOKKOS_INLINE_FUNCTION
    std::enable_if_t<std::is_same<T, unsigned short>::value, T>
    cast_from_bhalf(bhalf_t val);

template <class T>
KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, unsigned int>::value, T>
cast_from_bhalf(bhalf_t val);

template <class T>
KOKKOS_INLINE_FUNCTION
    std::enable_if_t<std::is_same<T, unsigned long>::value, T>
    cast_from_bhalf(bhalf_t val);

template <class T>
KOKKOS_INLINE_FUNCTION
    std::enable_if_t<std::is_same<T, unsigned long long>::value, T>
    cast_from_bhalf(bhalf_t val);
#endif  // KOKKOS_IMPL_BHALF_TYPE_DEFINED
179 
180 template <class T>
181 static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::half_t cast_to_wrapper(
182  T x, const volatile Kokkos::Impl::half_impl_t::type&);
183 
184 #ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED
185 template <class T>
186 static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::bhalf_t cast_to_wrapper(
187  T x, const volatile Kokkos::Impl::bhalf_impl_t::type&);
188 #endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED
189 
190 template <class T>
191 static KOKKOS_INLINE_FUNCTION T
192 cast_from_wrapper(const Kokkos::Experimental::half_t& x);
193 
194 #ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED
195 template <class T>
196 static KOKKOS_INLINE_FUNCTION T
197 cast_from_wrapper(const Kokkos::Experimental::bhalf_t& x);
198 #endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED
199 /************************** END forward declarations **************************/
200 
201 namespace Impl {
202 template <class FloatType>
203 class alignas(FloatType) floating_point_wrapper {
204  public:
205  using impl_type = FloatType;
206 
207  private:
208  impl_type val;
209  using fixed_width_integer_type = std::conditional_t<
210  sizeof(impl_type) == 2, uint16_t,
211  std::conditional_t<
212  sizeof(impl_type) == 4, uint32_t,
213  std::conditional_t<sizeof(impl_type) == 8, uint64_t, void>>>;
214  static_assert(!std::is_void<fixed_width_integer_type>::value,
215  "Invalid impl_type");
216 
217  public:
218  // In-class initialization and defaulted default constructors not used
219  // since Cuda supports half precision initialization via the below constructor
220  KOKKOS_FUNCTION
221  floating_point_wrapper() : val(0.0F) {}
222 
223 // Copy constructors
224 // Getting "C2580: multiple versions of a defaulted special
225 // member function are not allowed" with VS 16.11.3 and CUDA 11.4.2
226 #if defined(_WIN32) && defined(KOKKOS_ENABLE_CUDA)
227  KOKKOS_FUNCTION
228  floating_point_wrapper(const floating_point_wrapper& rhs) : val(rhs.val) {}
229 #else
230  KOKKOS_DEFAULTED_FUNCTION
231  floating_point_wrapper(const floating_point_wrapper&) noexcept = default;
232 #endif
233 
234  KOKKOS_INLINE_FUNCTION
235  floating_point_wrapper(const volatile floating_point_wrapper& rhs) {
236 #if defined(KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH) && !defined(KOKKOS_ENABLE_SYCL)
237  val = rhs.val;
238 #else
239  const volatile fixed_width_integer_type* rv_ptr =
240  reinterpret_cast<const volatile fixed_width_integer_type*>(&rhs.val);
241  const fixed_width_integer_type rv_val = *rv_ptr;
242  val = reinterpret_cast<const impl_type&>(rv_val);
243 #endif // KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
244  }
245 
246  // Don't support implicit conversion back to impl_type.
247  // impl_type is a storage only type on host.
248  KOKKOS_FUNCTION
249  explicit operator impl_type() const { return val; }
250  KOKKOS_FUNCTION
251  explicit operator float() const { return cast_from_wrapper<float>(*this); }
252  KOKKOS_FUNCTION
253  explicit operator bool() const { return cast_from_wrapper<bool>(*this); }
254  KOKKOS_FUNCTION
255  explicit operator double() const { return cast_from_wrapper<double>(*this); }
256  KOKKOS_FUNCTION
257  explicit operator short() const { return cast_from_wrapper<short>(*this); }
258  KOKKOS_FUNCTION
259  explicit operator int() const { return cast_from_wrapper<int>(*this); }
260  KOKKOS_FUNCTION
261  explicit operator long() const { return cast_from_wrapper<long>(*this); }
262  KOKKOS_FUNCTION
263  explicit operator long long() const {
264  return cast_from_wrapper<long long>(*this);
265  }
266  KOKKOS_FUNCTION
267  explicit operator unsigned short() const {
268  return cast_from_wrapper<unsigned short>(*this);
269  }
270  KOKKOS_FUNCTION
271  explicit operator unsigned int() const {
272  return cast_from_wrapper<unsigned int>(*this);
273  }
274  KOKKOS_FUNCTION
275  explicit operator unsigned long() const {
276  return cast_from_wrapper<unsigned long>(*this);
277  }
278  KOKKOS_FUNCTION
279  explicit operator unsigned long long() const {
280  return cast_from_wrapper<unsigned long long>(*this);
281  }
282 
297  KOKKOS_FUNCTION
298  constexpr floating_point_wrapper(impl_type rhs) : val(rhs) {}
299  KOKKOS_FUNCTION
300  floating_point_wrapper(float rhs) : val(cast_to_wrapper(rhs, val).val) {}
301  KOKKOS_FUNCTION
302  floating_point_wrapper(double rhs) : val(cast_to_wrapper(rhs, val).val) {}
303  KOKKOS_FUNCTION
304  explicit floating_point_wrapper(bool rhs)
305  : val(cast_to_wrapper(rhs, val).val) {}
306  KOKKOS_FUNCTION
307  floating_point_wrapper(short rhs) : val(cast_to_wrapper(rhs, val).val) {}
308  KOKKOS_FUNCTION
309  floating_point_wrapper(int rhs) : val(cast_to_wrapper(rhs, val).val) {}
310  KOKKOS_FUNCTION
311  floating_point_wrapper(long rhs) : val(cast_to_wrapper(rhs, val).val) {}
312  KOKKOS_FUNCTION
313  floating_point_wrapper(long long rhs) : val(cast_to_wrapper(rhs, val).val) {}
314  KOKKOS_FUNCTION
315  floating_point_wrapper(unsigned short rhs)
316  : val(cast_to_wrapper(rhs, val).val) {}
317  KOKKOS_FUNCTION
318  floating_point_wrapper(unsigned int rhs)
319  : val(cast_to_wrapper(rhs, val).val) {}
320  KOKKOS_FUNCTION
321  floating_point_wrapper(unsigned long rhs)
322  : val(cast_to_wrapper(rhs, val).val) {}
323  KOKKOS_FUNCTION
324  floating_point_wrapper(unsigned long long rhs)
325  : val(cast_to_wrapper(rhs, val).val) {}
326 
327  // Unary operators
328  KOKKOS_FUNCTION
329  floating_point_wrapper operator+() const {
330  floating_point_wrapper tmp = *this;
331 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
332  tmp.val = +tmp.val;
333 #else
334  tmp.val = cast_to_wrapper(+cast_from_wrapper<float>(tmp), val).val;
335 #endif
336  return tmp;
337  }
338 
339  KOKKOS_FUNCTION
340  floating_point_wrapper operator-() const {
341  floating_point_wrapper tmp = *this;
342 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
343  tmp.val = -tmp.val;
344 #else
345  tmp.val = cast_to_wrapper(-cast_from_wrapper<float>(tmp), val).val;
346 #endif
347  return tmp;
348  }
349 
350  // Prefix operators
351  KOKKOS_FUNCTION
352  floating_point_wrapper& operator++() {
353 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
354  val = val + impl_type(1.0F); // cuda has no operator++ for __nv_bfloat
355 #else
356  float tmp = cast_from_wrapper<float>(*this);
357  ++tmp;
358  val = cast_to_wrapper(tmp, val).val;
359 #endif
360  return *this;
361  }
362 
363  KOKKOS_FUNCTION
364  floating_point_wrapper& operator--() {
365 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
366  val = val - impl_type(1.0F); // cuda has no operator-- for __nv_bfloat
367 #else
368  float tmp = cast_from_wrapper<float>(*this);
369  --tmp;
370  val = cast_to_wrapper(tmp, val).val;
371 #endif
372  return *this;
373  }
374 
375  // Postfix operators
376  KOKKOS_FUNCTION
377  floating_point_wrapper operator++(int) {
378  floating_point_wrapper tmp = *this;
379  operator++();
380  return tmp;
381  }
382 
383  KOKKOS_FUNCTION
384  floating_point_wrapper operator--(int) {
385  floating_point_wrapper tmp = *this;
386  operator--();
387  return tmp;
388  }
389 
390  // Binary operators
391  KOKKOS_FUNCTION
392  floating_point_wrapper& operator=(impl_type rhs) {
393  val = rhs;
394  return *this;
395  }
396 
397  template <class T>
398  KOKKOS_FUNCTION floating_point_wrapper& operator=(T rhs) {
399  val = cast_to_wrapper(rhs, val).val;
400  return *this;
401  }
402 
403  template <class T>
404  KOKKOS_FUNCTION void operator=(T rhs) volatile {
405  impl_type new_val = cast_to_wrapper(rhs, val).val;
406  volatile fixed_width_integer_type* val_ptr =
407  reinterpret_cast<volatile fixed_width_integer_type*>(
408  const_cast<impl_type*>(&val));
409  *val_ptr = reinterpret_cast<fixed_width_integer_type&>(new_val);
410  }
411 
412  // Compound operators
413  KOKKOS_FUNCTION
414  floating_point_wrapper& operator+=(floating_point_wrapper rhs) {
415 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
416  val = val + rhs.val; // cuda has no operator+= for __nv_bfloat
417 #else
418  val = cast_to_wrapper(
419  cast_from_wrapper<float>(*this) + cast_from_wrapper<float>(rhs),
420  val)
421  .val;
422 #endif
423  return *this;
424  }
425 
426  KOKKOS_FUNCTION
427  void operator+=(const volatile floating_point_wrapper& rhs) volatile {
428  floating_point_wrapper tmp_rhs = rhs;
429  floating_point_wrapper tmp_lhs = *this;
430 
431  tmp_lhs += tmp_rhs;
432  *this = tmp_lhs;
433  }
434 
435  // Compound operators: upcast overloads for +=
436  template <class T>
437  KOKKOS_FUNCTION friend std::enable_if_t<
438  std::is_same<T, float>::value || std::is_same<T, double>::value, T>
439  operator+=(T& lhs, floating_point_wrapper rhs) {
440  lhs += static_cast<T>(rhs);
441  return lhs;
442  }
443 
444  KOKKOS_FUNCTION
445  floating_point_wrapper& operator+=(float rhs) {
446  float result = static_cast<float>(val) + rhs;
447  val = static_cast<impl_type>(result);
448  return *this;
449  }
450 
451  KOKKOS_FUNCTION
452  floating_point_wrapper& operator+=(double rhs) {
453  double result = static_cast<double>(val) + rhs;
454  val = static_cast<impl_type>(result);
455  return *this;
456  }
457 
458  KOKKOS_FUNCTION
459  floating_point_wrapper& operator-=(floating_point_wrapper rhs) {
460 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
461  val = val - rhs.val; // cuda has no operator-= for __nv_bfloat
462 #else
463  val = cast_to_wrapper(
464  cast_from_wrapper<float>(*this) - cast_from_wrapper<float>(rhs),
465  val)
466  .val;
467 #endif
468  return *this;
469  }
470 
471  KOKKOS_FUNCTION
472  void operator-=(const volatile floating_point_wrapper& rhs) volatile {
473  floating_point_wrapper tmp_rhs = rhs;
474  floating_point_wrapper tmp_lhs = *this;
475 
476  tmp_lhs -= tmp_rhs;
477  *this = tmp_lhs;
478  }
479 
480  // Compund operators: upcast overloads for -=
481  template <class T>
482  KOKKOS_FUNCTION friend std::enable_if_t<
483  std::is_same<T, float>::value || std::is_same<T, double>::value, T>
484  operator-=(T& lhs, floating_point_wrapper rhs) {
485  lhs -= static_cast<T>(rhs);
486  return lhs;
487  }
488 
489  KOKKOS_FUNCTION
490  floating_point_wrapper& operator-=(float rhs) {
491  float result = static_cast<float>(val) - rhs;
492  val = static_cast<impl_type>(result);
493  return *this;
494  }
495 
496  KOKKOS_FUNCTION
497  floating_point_wrapper& operator-=(double rhs) {
498  double result = static_cast<double>(val) - rhs;
499  val = static_cast<impl_type>(result);
500  return *this;
501  }
502 
503  KOKKOS_FUNCTION
504  floating_point_wrapper& operator*=(floating_point_wrapper rhs) {
505 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
506  val = val * rhs.val; // cuda has no operator*= for __nv_bfloat
507 #else
508  val = cast_to_wrapper(
509  cast_from_wrapper<float>(*this) * cast_from_wrapper<float>(rhs),
510  val)
511  .val;
512 #endif
513  return *this;
514  }
515 
516  KOKKOS_FUNCTION
517  void operator*=(const volatile floating_point_wrapper& rhs) volatile {
518  floating_point_wrapper tmp_rhs = rhs;
519  floating_point_wrapper tmp_lhs = *this;
520 
521  tmp_lhs *= tmp_rhs;
522  *this = tmp_lhs;
523  }
524 
525  // Compund operators: upcast overloads for *=
526  template <class T>
527  KOKKOS_FUNCTION friend std::enable_if_t<
528  std::is_same<T, float>::value || std::is_same<T, double>::value, T>
529  operator*=(T& lhs, floating_point_wrapper rhs) {
530  lhs *= static_cast<T>(rhs);
531  return lhs;
532  }
533 
534  KOKKOS_FUNCTION
535  floating_point_wrapper& operator*=(float rhs) {
536  float result = static_cast<float>(val) * rhs;
537  val = static_cast<impl_type>(result);
538  return *this;
539  }
540 
541  KOKKOS_FUNCTION
542  floating_point_wrapper& operator*=(double rhs) {
543  double result = static_cast<double>(val) * rhs;
544  val = static_cast<impl_type>(result);
545  return *this;
546  }
547 
548  KOKKOS_FUNCTION
549  floating_point_wrapper& operator/=(floating_point_wrapper rhs) {
550 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
551  val = val / rhs.val; // cuda has no operator/= for __nv_bfloat
552 #else
553  val = cast_to_wrapper(
554  cast_from_wrapper<float>(*this) / cast_from_wrapper<float>(rhs),
555  val)
556  .val;
557 #endif
558  return *this;
559  }
560 
561  KOKKOS_FUNCTION
562  void operator/=(const volatile floating_point_wrapper& rhs) volatile {
563  floating_point_wrapper tmp_rhs = rhs;
564  floating_point_wrapper tmp_lhs = *this;
565 
566  tmp_lhs /= tmp_rhs;
567  *this = tmp_lhs;
568  }
569 
570  // Compund operators: upcast overloads for /=
571  template <class T>
572  KOKKOS_FUNCTION friend std::enable_if_t<
573  std::is_same<T, float>::value || std::is_same<T, double>::value, T>
574  operator/=(T& lhs, floating_point_wrapper rhs) {
575  lhs /= static_cast<T>(rhs);
576  return lhs;
577  }
578 
579  KOKKOS_FUNCTION
580  floating_point_wrapper& operator/=(float rhs) {
581  float result = static_cast<float>(val) / rhs;
582  val = static_cast<impl_type>(result);
583  return *this;
584  }
585 
586  KOKKOS_FUNCTION
587  floating_point_wrapper& operator/=(double rhs) {
588  double result = static_cast<double>(val) / rhs;
589  val = static_cast<impl_type>(result);
590  return *this;
591  }
592 
593  // Binary Arithmetic
594  KOKKOS_FUNCTION
595  friend floating_point_wrapper operator+(floating_point_wrapper lhs,
596  floating_point_wrapper rhs) {
597 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
598  lhs += rhs;
599 #else
600  lhs.val = cast_to_wrapper(
601  cast_from_wrapper<float>(lhs) + cast_from_wrapper<float>(rhs),
602  lhs.val)
603  .val;
604 #endif
605  return lhs;
606  }
607 
608  // Binary Arithmetic upcast operators for +
609  template <class T>
610  KOKKOS_FUNCTION friend std::enable_if_t<
611  std::is_same<T, float>::value || std::is_same<T, double>::value, T>
612  operator+(floating_point_wrapper lhs, T rhs) {
613  return T(lhs) + rhs;
614  }
615 
616  template <class T>
617  KOKKOS_FUNCTION friend std::enable_if_t<
618  std::is_same<T, float>::value || std::is_same<T, double>::value, T>
619  operator+(T lhs, floating_point_wrapper rhs) {
620  return lhs + T(rhs);
621  }
622 
623  KOKKOS_FUNCTION
624  friend floating_point_wrapper operator-(floating_point_wrapper lhs,
625  floating_point_wrapper rhs) {
626 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
627  lhs -= rhs;
628 #else
629  lhs.val = cast_to_wrapper(
630  cast_from_wrapper<float>(lhs) - cast_from_wrapper<float>(rhs),
631  lhs.val)
632  .val;
633 #endif
634  return lhs;
635  }
636 
637  // Binary Arithmetic upcast operators for -
638  template <class T>
639  KOKKOS_FUNCTION friend std::enable_if_t<
640  std::is_same<T, float>::value || std::is_same<T, double>::value, T>
641  operator-(floating_point_wrapper lhs, T rhs) {
642  return T(lhs) - rhs;
643  }
644 
645  template <class T>
646  KOKKOS_FUNCTION friend std::enable_if_t<
647  std::is_same<T, float>::value || std::is_same<T, double>::value, T>
648  operator-(T lhs, floating_point_wrapper rhs) {
649  return lhs - T(rhs);
650  }
651 
652  KOKKOS_FUNCTION
653  friend floating_point_wrapper operator*(floating_point_wrapper lhs,
654  floating_point_wrapper rhs) {
655 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
656  lhs *= rhs;
657 #else
658  lhs.val = cast_to_wrapper(
659  cast_from_wrapper<float>(lhs) * cast_from_wrapper<float>(rhs),
660  lhs.val)
661  .val;
662 #endif
663  return lhs;
664  }
665 
666  // Binary Arithmetic upcast operators for *
667  template <class T>
668  KOKKOS_FUNCTION friend std::enable_if_t<
669  std::is_same<T, float>::value || std::is_same<T, double>::value, T>
670  operator*(floating_point_wrapper lhs, T rhs) {
671  return T(lhs) * rhs;
672  }
673 
674  template <class T>
675  KOKKOS_FUNCTION friend std::enable_if_t<
676  std::is_same<T, float>::value || std::is_same<T, double>::value, T>
677  operator*(T lhs, floating_point_wrapper rhs) {
678  return lhs * T(rhs);
679  }
680 
681  KOKKOS_FUNCTION
682  friend floating_point_wrapper operator/(floating_point_wrapper lhs,
683  floating_point_wrapper rhs) {
684 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
685  lhs /= rhs;
686 #else
687  lhs.val = cast_to_wrapper(
688  cast_from_wrapper<float>(lhs) / cast_from_wrapper<float>(rhs),
689  lhs.val)
690  .val;
691 #endif
692  return lhs;
693  }
694 
695  // Binary Arithmetic upcast operators for /
696  template <class T>
697  KOKKOS_FUNCTION friend std::enable_if_t<
698  std::is_same<T, float>::value || std::is_same<T, double>::value, T>
699  operator/(floating_point_wrapper lhs, T rhs) {
700  return T(lhs) / rhs;
701  }
702 
703  template <class T>
704  KOKKOS_FUNCTION friend std::enable_if_t<
705  std::is_same<T, float>::value || std::is_same<T, double>::value, T>
706  operator/(T lhs, floating_point_wrapper rhs) {
707  return lhs / T(rhs);
708  }
709 
710  // Logical operators
711  KOKKOS_FUNCTION
712  bool operator!() const {
713 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
714  return static_cast<bool>(!val);
715 #else
716  return !cast_from_wrapper<float>(*this);
717 #endif
718  }
719 
720  // NOTE: Loses short-circuit evaluation
721  KOKKOS_FUNCTION
722  bool operator&&(floating_point_wrapper rhs) const {
723 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
724  return static_cast<bool>(val && rhs.val);
725 #else
726  return cast_from_wrapper<float>(*this) && cast_from_wrapper<float>(rhs);
727 #endif
728  }
729 
730  // NOTE: Loses short-circuit evaluation
731  KOKKOS_FUNCTION
732  bool operator||(floating_point_wrapper rhs) const {
733 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
734  return static_cast<bool>(val || rhs.val);
735 #else
736  return cast_from_wrapper<float>(*this) || cast_from_wrapper<float>(rhs);
737 #endif
738  }
739 
740  // Comparison operators
741  KOKKOS_FUNCTION
742  bool operator==(floating_point_wrapper rhs) const {
743 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
744  return static_cast<bool>(val == rhs.val);
745 #else
746  return cast_from_wrapper<float>(*this) == cast_from_wrapper<float>(rhs);
747 #endif
748  }
749 
750  KOKKOS_FUNCTION
751  bool operator!=(floating_point_wrapper rhs) const {
752 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
753  return static_cast<bool>(val != rhs.val);
754 #else
755  return cast_from_wrapper<float>(*this) != cast_from_wrapper<float>(rhs);
756 #endif
757  }
758 
759  KOKKOS_FUNCTION
760  bool operator<(floating_point_wrapper rhs) const {
761 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
762  return static_cast<bool>(val < rhs.val);
763 #else
764  return cast_from_wrapper<float>(*this) < cast_from_wrapper<float>(rhs);
765 #endif
766  }
767 
768  KOKKOS_FUNCTION
769  bool operator>(floating_point_wrapper rhs) const {
770 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
771  return static_cast<bool>(val > rhs.val);
772 #else
773  return cast_from_wrapper<float>(*this) > cast_from_wrapper<float>(rhs);
774 #endif
775  }
776 
777  KOKKOS_FUNCTION
778  bool operator<=(floating_point_wrapper rhs) const {
779 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
780  return static_cast<bool>(val <= rhs.val);
781 #else
782  return cast_from_wrapper<float>(*this) <= cast_from_wrapper<float>(rhs);
783 #endif
784  }
785 
786  KOKKOS_FUNCTION
787  bool operator>=(floating_point_wrapper rhs) const {
788 #ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
789  return static_cast<bool>(val >= rhs.val);
790 #else
791  return cast_from_wrapper<float>(*this) >= cast_from_wrapper<float>(rhs);
792 #endif
793  }
794 
795  KOKKOS_FUNCTION
796  friend bool operator==(const volatile floating_point_wrapper& lhs,
797  const volatile floating_point_wrapper& rhs) {
798  floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs;
799  return tmp_lhs == tmp_rhs;
800  }
801 
802  KOKKOS_FUNCTION
803  friend bool operator!=(const volatile floating_point_wrapper& lhs,
804  const volatile floating_point_wrapper& rhs) {
805  floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs;
806  return tmp_lhs != tmp_rhs;
807  }
808 
809  KOKKOS_FUNCTION
810  friend bool operator<(const volatile floating_point_wrapper& lhs,
811  const volatile floating_point_wrapper& rhs) {
812  floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs;
813  return tmp_lhs < tmp_rhs;
814  }
815 
816  KOKKOS_FUNCTION
817  friend bool operator>(const volatile floating_point_wrapper& lhs,
818  const volatile floating_point_wrapper& rhs) {
819  floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs;
820  return tmp_lhs > tmp_rhs;
821  }
822 
823  KOKKOS_FUNCTION
824  friend bool operator<=(const volatile floating_point_wrapper& lhs,
825  const volatile floating_point_wrapper& rhs) {
826  floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs;
827  return tmp_lhs <= tmp_rhs;
828  }
829 
830  KOKKOS_FUNCTION
831  friend bool operator>=(const volatile floating_point_wrapper& lhs,
832  const volatile floating_point_wrapper& rhs) {
833  floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs;
834  return tmp_lhs >= tmp_rhs;
835  }
836 
837  // Insertion and extraction operators
838  friend std::ostream& operator<<(std::ostream& os,
839  const floating_point_wrapper& x) {
840  const std::string out = std::to_string(static_cast<double>(x));
841  os << out;
842  return os;
843  }
844 
845  friend std::istream& operator>>(std::istream& is, floating_point_wrapper& x) {
846  std::string in;
847  is >> in;
848  x = std::stod(in);
849  return is;
850  }
851 };
852 } // namespace Impl
853 
854 // Declare wrapper overloads now that floating_point_wrapper is declared
855 template <class T>
856 static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::half_t cast_to_wrapper(
857  T x, const volatile Kokkos::Impl::half_impl_t::type&) {
858  return Kokkos::Experimental::cast_to_half(x);
859 }
860 
861 #ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED
862 template <class T>
863 static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::bhalf_t cast_to_wrapper(
864  T x, const volatile Kokkos::Impl::bhalf_impl_t::type&) {
865  return Kokkos::Experimental::cast_to_bhalf(x);
866 }
867 #endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED
868 
869 template <class T>
870 static KOKKOS_INLINE_FUNCTION T
871 cast_from_wrapper(const Kokkos::Experimental::half_t& x) {
872  return Kokkos::Experimental::cast_from_half<T>(x);
873 }
874 
875 #ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED
876 template <class T>
877 static KOKKOS_INLINE_FUNCTION T
878 cast_from_wrapper(const Kokkos::Experimental::bhalf_t& x) {
879  return Kokkos::Experimental::cast_from_bhalf<T>(x);
880 }
881 #endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED
882 
883 } // namespace Experimental
884 } // namespace Kokkos
885 
886 #endif // KOKKOS_IMPL_HALF_TYPE_DEFINED
887 
888 // If none of the above actually did anything and defined a half precision type
889 // define a fallback implementation here using float
890 #ifndef KOKKOS_IMPL_HALF_TYPE_DEFINED
891 #define KOKKOS_IMPL_HALF_TYPE_DEFINED
892 #define KOKKOS_HALF_T_IS_FLOAT true
893 namespace Kokkos {
894 namespace Impl {
895 struct half_impl_t {
896  using type = float;
897 };
898 } // namespace Impl
899 namespace Experimental {
900 
901 using half_t = Kokkos::Impl::half_impl_t::type;
902 
903 // cast_to_half
904 KOKKOS_INLINE_FUNCTION
905 half_t cast_to_half(float val) { return half_t(val); }
906 KOKKOS_INLINE_FUNCTION
907 half_t cast_to_half(bool val) { return half_t(val); }
908 KOKKOS_INLINE_FUNCTION
909 half_t cast_to_half(double val) { return half_t(val); }
910 KOKKOS_INLINE_FUNCTION
911 half_t cast_to_half(short val) { return half_t(val); }
912 KOKKOS_INLINE_FUNCTION
913 half_t cast_to_half(unsigned short val) { return half_t(val); }
914 KOKKOS_INLINE_FUNCTION
915 half_t cast_to_half(int val) { return half_t(val); }
916 KOKKOS_INLINE_FUNCTION
917 half_t cast_to_half(unsigned int val) { return half_t(val); }
918 KOKKOS_INLINE_FUNCTION
919 half_t cast_to_half(long val) { return half_t(val); }
920 KOKKOS_INLINE_FUNCTION
921 half_t cast_to_half(unsigned long val) { return half_t(val); }
922 KOKKOS_INLINE_FUNCTION
923 half_t cast_to_half(long long val) { return half_t(val); }
924 KOKKOS_INLINE_FUNCTION
925 half_t cast_to_half(unsigned long long val) { return half_t(val); }
926 
927 // cast_from_half
928 // Using an explicit list here too, since the other ones are explicit and for
929 // example don't include char
930 template <class T>
931 KOKKOS_INLINE_FUNCTION std::enable_if_t<
932  std::is_same<T, float>::value || std::is_same<T, bool>::value ||
933  std::is_same<T, double>::value || std::is_same<T, short>::value ||
934  std::is_same<T, unsigned short>::value || std::is_same<T, int>::value ||
935  std::is_same<T, unsigned int>::value || std::is_same<T, long>::value ||
936  std::is_same<T, unsigned long>::value ||
937  std::is_same<T, long long>::value ||
938  std::is_same<T, unsigned long long>::value,
939  T>
940 cast_from_half(half_t val) {
941  return T(val);
942 }
943 
944 } // namespace Experimental
945 } // namespace Kokkos
946 
947 #else
948 #define KOKKOS_HALF_T_IS_FLOAT false
949 #endif // KOKKOS_IMPL_HALF_TYPE_DEFINED
950 
951 #ifndef KOKKOS_IMPL_BHALF_TYPE_DEFINED
952 #define KOKKOS_IMPL_BHALF_TYPE_DEFINED
953 #define KOKKOS_BHALF_T_IS_FLOAT true
954 namespace Kokkos {
955 namespace Impl {
956 struct bhalf_impl_t {
957  using type = float;
958 };
959 } // namespace Impl
960 
961 namespace Experimental {
962 
963 using bhalf_t = Kokkos::Impl::bhalf_impl_t::type;
964 
965 // cast_to_bhalf
966 KOKKOS_INLINE_FUNCTION
967 bhalf_t cast_to_bhalf(float val) { return bhalf_t(val); }
968 KOKKOS_INLINE_FUNCTION
969 bhalf_t cast_to_bhalf(bool val) { return bhalf_t(val); }
970 KOKKOS_INLINE_FUNCTION
971 bhalf_t cast_to_bhalf(double val) { return bhalf_t(val); }
972 KOKKOS_INLINE_FUNCTION
973 bhalf_t cast_to_bhalf(short val) { return bhalf_t(val); }
974 KOKKOS_INLINE_FUNCTION
975 bhalf_t cast_to_bhalf(unsigned short val) { return bhalf_t(val); }
976 KOKKOS_INLINE_FUNCTION
977 bhalf_t cast_to_bhalf(int val) { return bhalf_t(val); }
978 KOKKOS_INLINE_FUNCTION
979 bhalf_t cast_to_bhalf(unsigned int val) { return bhalf_t(val); }
980 KOKKOS_INLINE_FUNCTION
981 bhalf_t cast_to_bhalf(long val) { return bhalf_t(val); }
982 KOKKOS_INLINE_FUNCTION
983 bhalf_t cast_to_bhalf(unsigned long val) { return bhalf_t(val); }
984 KOKKOS_INLINE_FUNCTION
985 bhalf_t cast_to_bhalf(long long val) { return bhalf_t(val); }
986 KOKKOS_INLINE_FUNCTION
987 bhalf_t cast_to_bhalf(unsigned long long val) { return bhalf_t(val); }
988 
989 // cast_from_bhalf
990 template <class T>
991 KOKKOS_INLINE_FUNCTION std::enable_if_t<
992  std::is_same<T, float>::value || std::is_same<T, bool>::value ||
993  std::is_same<T, double>::value || std::is_same<T, short>::value ||
994  std::is_same<T, unsigned short>::value || std::is_same<T, int>::value ||
995  std::is_same<T, unsigned int>::value || std::is_same<T, long>::value ||
996  std::is_same<T, unsigned long>::value ||
997  std::is_same<T, long long>::value ||
998  std::is_same<T, unsigned long long>::value,
999  T>
1000 cast_from_bhalf(bhalf_t val) {
1001  return T(val);
1002 }
1003 } // namespace Experimental
1004 } // namespace Kokkos
1005 #else
1006 #define KOKKOS_BHALF_T_IS_FLOAT false
1007 #endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED
1008 #ifdef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_HALF
1009 #undef KOKKOS_IMPL_PUBLIC_INCLUDE
1010 #undef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_HALF
1011 #endif
1012 #endif // KOKKOS_HALF_HPP_
KOKKOS_FORCEINLINE_FUNCTION constexpr bool operator<(const pair< T1, T2 > &lhs, const pair< T1, T2 > &rhs)
Less-than operator for Kokkos::pair.
KOKKOS_FORCEINLINE_FUNCTION constexpr bool operator<=(const pair< T1, T2 > &lhs, const pair< T1, T2 > &rhs)
Less-than-or-equal-to operator for Kokkos::pair.
KOKKOS_FORCEINLINE_FUNCTION constexpr bool operator>=(const pair< T1, T2 > &lhs, const pair< T1, T2 > &rhs)
Greater-than-or-equal-to operator for Kokkos::pair.
KOKKOS_FORCEINLINE_FUNCTION constexpr bool operator>(const pair< T1, T2 > &lhs, const pair< T1, T2 > &rhs)
Greater-than operator for Kokkos::pair.
Definition: dummy.cpp:17