amd_buffer_addressing.hpp Source File

amd_buffer_addressing.hpp Source File

Composable Kernel: amd_buffer_addressing.hpp Source File
tile/core/arch/amd_buffer_addressing.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8#if !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
9
19
20// This attribute gives a hint to the compiler that a branch is likely to be taken.
21// Then, the compiler should remove if possible the associated s_cbranch_execz branch that would
22// have been generated.
// Usage: `if LIKELY(cond) { ... }` without extra parentheses around LIKELY(...): both expansions
// below already supply the parenthesized condition the if-statement needs
// (C++20: `(cond) [[likely]]`; otherwise: `(__builtin_expect(!!(cond), 1))`).
23#if __cplusplus >= 202002L
24#define LIKELY(x) (x) [[likely]]
25#else
26#define LIKELY(x) (__builtin_expect(!!(x), 1))
27#endif
28
// Pointer to uint32_t in address space 3 (the AMDGPU local / LDS address space).
29using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
30
31namespace ck_tile {
32
33// amd_wave_read_first_lane returns the value held by the first active lane of the wavefront
34// (via __builtin_amdgcn_readfirstlane), i.e. a wave-uniform copy that can live in SGPRs.
// Overload: converts v to uint32_t before broadcasting it.
36{
37 return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
38}
39
// Overload: converts v to uint32_t before broadcasting it.
41{
42 return __builtin_amdgcn_readfirstlane(static_cast<uint32_t>(v));
43}
44
// Overload: passes `value` to readfirstlane without an explicit conversion.
46{
47 return __builtin_amdgcn_readfirstlane(value);
48}
49
// Overload: passes `value` to readfirstlane without an explicit conversion.
51{
52 return __builtin_amdgcn_readfirstlane(value);
53}
54
55template <typename Object, std::enable_if_t<std::is_trivially_copyable_v<Object>, int> = 0>
56__device__ inline auto amd_wave_read_first_lane(const Object& obj)
57{
58 constexpr size_t ObjectSize = sizeof(Object);
59 constexpr size_t SGPR_size = 4;
60 constexpr size_t NumFull = ObjectSize / SGPR_size;
61 constexpr size_t Tail = ObjectSize % SGPR_size;
62
63 const unsigned char* src = reinterpret_cast<const unsigned char*>(&obj);
64 alignas(Object) unsigned char dst[ObjectSize];
65
66 static_for<0, NumFull, 1>{}([&](auto Ic) {
67 constexpr size_t offset = Ic * SGPR_size;
68 uint32_t read_src;
69 __builtin_memcpy(&read_src, src + offset, SGPR_size);
70 read_src = __builtin_amdgcn_readfirstlane(read_src);
71 __builtin_memcpy(dst + offset, &read_src, SGPR_size);
72 });
73
74 if constexpr(Tail != 0)
75 {
76 constexpr size_t offset = NumFull * SGPR_size;
77 uint32_t tail_loc = 0;
78 __builtin_memcpy(&tail_loc, src + offset, Tail);
79 tail_loc = __builtin_amdgcn_readfirstlane(tail_loc);
80 __builtin_memcpy(dst + offset, &tail_loc, Tail);
81 }
82 Object out;
83 __builtin_memcpy(&out, dst, ObjectSize);
84 return out;
85}
86
87// 128 bit SGPRs to supply buffer resource in buffer instructions
88// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
// make_wave_buffer_resource aggregate-initializes this as {ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}
// and __builtin_bit_casts it to int32x4_t. Because __builtin_bit_cast requires equal sizes, the
// packed layout must total exactly 16 bytes.
89struct __attribute__((packed)) buffer_resource
90{
91 const void* ptr;
94};
95
// Build a buffer resource descriptor for `ptr` with range `size` (default 0xffffffff; presumably
// in bytes -- confirm). It is returned reinterpreted as int32x4_t, the form the buffer_load /
// buffer_store helpers take as `res`.
// Passing ForceSGPR = std::true_type enables an extra step on `r`; presumably this makes the
// descriptor wave-uniform (e.g. via amd_wave_read_first_lane) -- TODO confirm.
96template <typename ForceSGPR = std::false_type>
98 uint32_t size = 0xffffffff,
99 ForceSGPR = {})
100{
101 buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
102 int32x4_t r = __builtin_bit_cast(int32x4_t, res);
103 if constexpr(std::is_same_v<ForceSGPR, std::true_type>)
104 {
106 }
107 return r;
108}
109
110namespace impl {
111// below type indicate the data type used for buffer load inline asm
112// clang-format off
113template<index_t N, typename T> struct buffer_load_trait;
114
115template<typename T> struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; };
116template<typename T> struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; };
117template<typename T> struct buffer_load_trait<4 , T> { using payload_t = float; };
118template<typename T> struct buffer_load_trait<2 , T> { using payload_t = float; };
119template<typename T> struct buffer_load_trait<1 , T> { using payload_t = float; };
120
121#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA
122template<> struct buffer_load_trait<16, thread_buffer<bf16_t, 8>> { using payload_t = bf16x8_t; };
123template<> struct buffer_load_trait<8 , thread_buffer<bf16_t, 4>> { using payload_t = bf16x4_t; };
124template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payload_t = bf16x2_t; };
125#endif
126// clang-format on
127} // namespace impl
128
// Primary templates for the buffer access helpers specialized below (presumably, in order:
// buffer_load, buffer_load_if, buffer_store, buffer_store_if). They are parameterized by access
// width in bytes; the load templates also take pre_nop (default false).
129// TODO: glc/slc/... (cache-policy bits are not exposed yet)
130template <index_t bytes, bool pre_nop = false>
132
133template <index_t bytes, bool pre_nop = false>
135
136template <index_t bytes>
138
139template <index_t bytes>
141
// Silence -Wundefined-reinterpret-cast for the reinterpret_cast-based operand bindings that
// follow. The matching pop comes after the buffer_load_if definitions.
142#pragma clang diagnostic push
143#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
144// TODO: the strict aliasing rule seems to fail when reinterpret_cast-ing between vector types
145// (ext_vector_type(xxx))
146
// Nonzero when the compiler provides the typed raw-buffer builtins used below. Only
// __builtin_amdgcn_raw_buffer_load_b32, __builtin_amdgcn_make_buffer_rsrc and
// __builtin_amdgcn_raw_buffer_store_b32 are probed; the other widths (b8/b16/b64/b128) are
// presumably available alongside them.
// The whole conjunction is parenthesized so the macro composes safely with other operators;
// e.g. `#if !HAS_RAW_BUFFER_BUILTINS` negates the entire test, not just its first term.
#define HAS_RAW_BUFFER_BUILTINS                             \
    (__has_builtin(__builtin_amdgcn_raw_buffer_load_b32) && \
     __has_builtin(__builtin_amdgcn_make_buffer_rsrc) &&    \
     __has_builtin(__builtin_amdgcn_raw_buffer_store_b32))
151
#if HAS_RAW_BUFFER_BUILTINS
// Reinterpret a 128-bit buffer resource held as int32x4_t (e.g. the value returned by
// make_wave_buffer_resource) as the compiler's opaque __amdgpu_buffer_rsrc_t, which the
// __builtin_amdgcn_raw_buffer_* builtins expect. The bytes are copied verbatim.
CK_TILE_DEVICE __amdgpu_buffer_rsrc_t cast_to_amdgpu_buffer_rsrc_t(int32x4_t res)
{
    static_assert(sizeof(res) == sizeof(__amdgpu_buffer_rsrc_t),
                  "Size of buffer resource should match");
    __amdgpu_buffer_rsrc_t as_rsrc;
    // __builtin_memcpy, as elsewhere in this file, so no <cstring>/device memcpy is needed.
    __builtin_memcpy(&as_rsrc, &res, sizeof(res));
    return as_rsrc;
}
#endif
161
// Unconditional 16-byte (dwordx4) buffer load into `value` (sizeof(T) must be 16).
// The address is v_offset (per-lane) plus i_offset; the s_offset and flag parameters are ignored.
162template <bool pre_nop>
163struct buffer_load<16, pre_nop>
164{
165 template <typename T>
167 int32x4_t res /*buffer resource*/,
168 index_t v_offset,
169 index_t /*s_offset*/,
170 index_t i_offset /*max 0xFFF*/,
171 index_t /*flag*/ = 0,
173 {
174 static_assert(sizeof(T) == 16);
175 using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
176#if HAS_RAW_BUFFER_BUILTINS
// Builtin path: i_offset is forwarded as the soffset argument of the raw buffer load.
177 index_t s_offset = i_offset;
178 reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b128(
179 cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
180#else
// Inline-asm path: i_offset becomes the immediate ("n") offset; pre_nop prepends `s_nop 4`.
181 if constexpr(pre_nop)
182 asm volatile("s_nop 4\n"
183 "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
184 : "+v"(reinterpret_cast<mbuf_t&>(value))
185 : "v"(v_offset), "s"(res), "n"(i_offset)
186 : "memory");
187 else
188 asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
189 : "+v"(reinterpret_cast<mbuf_t&>(value))
190 : "v"(v_offset), "s"(res), "n"(i_offset)
191 : "memory");
192#endif
193 }
194};
195
// Unconditional 8-byte (dwordx2) buffer load into `value` (sizeof(T) must be 8).
// The address is v_offset (per-lane) plus i_offset; the s_offset and flag parameters are ignored.
196template <bool pre_nop>
197struct buffer_load<8, pre_nop>
198{
199 template <typename T>
201 int32x4_t res /*buffer resource*/,
202 index_t v_offset,
203 index_t /*s_offset*/,
204 index_t i_offset /*max 0xFFF*/,
205 index_t /*flag*/ = 0,
207 {
208 static_assert(sizeof(T) == 8);
209 using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
210#if HAS_RAW_BUFFER_BUILTINS
// Builtin path: i_offset is forwarded as the soffset argument of the raw buffer load.
211 index_t s_offset = i_offset;
212 reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b64(
213 cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
214#else
// Inline-asm path: i_offset becomes the immediate ("n") offset; pre_nop prepends `s_nop 4`.
215 if constexpr(pre_nop)
216 asm volatile("s_nop 4\n"
217 "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
218 : "+v"(reinterpret_cast<mbuf_t&>(value))
219 : "v"(v_offset), "s"(res), "n"(i_offset)
220 : "memory");
221 else
222 asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
223 : "+v"(reinterpret_cast<mbuf_t&>(value))
224 : "v"(v_offset), "s"(res), "n"(i_offset)
225 : "memory");
226#endif
227 }
228};
229
230template <bool pre_nop>
231struct buffer_load<4, pre_nop>
232{
233 template <typename T>
235 int32x4_t res /*buffer resource*/,
236 index_t v_offset,
237 index_t /*s_offset*/,
238 index_t i_offset /*max 0xFFF*/,
239 index_t /*flag*/ = 0,
241 {
242 static_assert(sizeof(T) == 4);
243 using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
244
245#if HAS_RAW_BUFFER_BUILTINS
246 index_t s_offset = i_offset;
247 reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b32(
248 cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
249#else
250 if constexpr(pre_nop)
251 asm volatile("s_nop 4\n"
252 "buffer_load_dword %0, %1, %2, 0 offen offset:%3"
253 : "+v"(reinterpret_cast<mbuf_t&>(value))
254 : "v"(v_offset), "s"(res), "n"(i_offset)
255 : "memory");
256 else
257 asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3"
258 : "+v"(reinterpret_cast<mbuf_t&>(value))
259 : "v"(v_offset), "s"(res), "n"(i_offset)
260 : "memory");
261#endif
262 }
263};
264
265template <bool pre_nop>
266struct buffer_load<2, pre_nop>
267{
268 template <typename T>
270 int32x4_t res /*buffer resource*/,
271 index_t v_offset,
272 index_t /*s_offset*/,
273 index_t i_offset /*max 0xFFF*/,
274 index_t /*flag*/ = 0,
276 {
277 static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
278 using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
279
280#if HAS_RAW_BUFFER_BUILTINS
281 index_t s_offset = i_offset;
282 reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b16(
283 cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
284#else
285 if constexpr(pre_nop)
286 asm volatile("s_nop 4\n"
287 "buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
288 : "+v"(reinterpret_cast<mbuf_t&>(value))
289 : "v"(v_offset), "s"(res), "n"(i_offset)
290 : "memory");
291 else
292 asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
293 : "+v"(reinterpret_cast<mbuf_t&>(value))
294 : "v"(v_offset), "s"(res), "n"(i_offset)
295 : "memory");
296#endif
297 }
298};
299
300template <bool pre_nop>
301struct buffer_load<1, pre_nop>
302{
303 template <typename T>
305 int32x4_t res /*buffer resource*/,
306 index_t v_offset,
307 index_t /*s_offset*/,
308 index_t i_offset /*max 0xFFF*/,
309 index_t /*flag*/ = 0,
311 {
312 static_assert(sizeof(T) == 4);
313 using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
314#if HAS_RAW_BUFFER_BUILTINS
315 index_t s_offset = i_offset;
316 reinterpret_cast<mbuf_t&>(value) = __builtin_amdgcn_raw_buffer_load_b16(
317 cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
318#else
319 if constexpr(pre_nop)
320 asm volatile("s_nop 4\n"
321 "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
322 : "+v"(reinterpret_cast<mbuf_t&>(value))
323 : "v"(v_offset), "s"(res), "n"(i_offset)
324 : "memory");
325 else
326 asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
327 : "+v"(reinterpret_cast<mbuf_t&>(value))
328 : "v"(v_offset), "s"(res), "n"(i_offset)
329 : "memory");
330#endif
331 }
332};
333
334#if HAS_RAW_BUFFER_BUILTINS
// Builtin path: predicated load. Performs buffer_load<bytes, pre_nop> only when flag >= 1.
// flag defaults to 0 (no load); when skipped, `value` is left untouched.
335template <index_t bytes, bool pre_nop>
336struct buffer_load_if
337{
338 template <typename T>
339 CK_TILE_DEVICE void operator()(T& value,
340 int32x4_t res /*buffer resource*/,
341 index_t v_offset,
342 index_t s_offset,
343 index_t i_offset /*max 0xFFF*/,
344 index_t flag = 0,
346 {
347 if LIKELY(1 <= flag)
348 {
349 buffer_load<bytes, pre_nop>{}(
350 value, res, v_offset, s_offset, i_offset, flag, bool_constant<pre_nop>{});
351 }
352 }
353};
354#else
// Asm path, predicated 16-byte (dwordx4) load. The exec mask is saved first. v_cmpx_le_u32 then
// narrows exec to the lanes with 1 <= flag, the load runs for those lanes, and s_mov_b64
// restores exec. Masked-off lanes keep their previous `value` ("+v" operand).
// NOTE(review): the compare is unsigned (u32), so a negative flag enables the load here. The
// builtin path's `1 <= flag` would not, if index_t is signed -- confirm callers only pass 0/1.
355template <bool pre_nop>
356struct buffer_load_if<16, pre_nop>
357{
358 template <typename T>
360 int32x4_t res /*buffer resource*/,
361 index_t v_offset,
362 index_t /*s_offset*/,
363 index_t i_offset /*max 0xFFF*/,
364 index_t flag = 0,
366 {
367 static_assert(sizeof(T) == 16);
368 auto saved_exec = __builtin_amdgcn_read_exec();
369 using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
370 static_assert(sizeof(mbuf_t) == sizeof(T));
371 if constexpr(pre_nop)
372 asm volatile("s_nop 4\n"
373 "v_cmpx_le_u32 exec, 1, %4\n"
374 "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
375 "s_mov_b64 exec %5"
376 : "+v"(reinterpret_cast<mbuf_t&>(value))
377 : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
378 : "memory");
379 else
380 asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
381 "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
382 "s_mov_b64 exec %5"
383 : "+v"(reinterpret_cast<mbuf_t&>(value))
384 : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
385 : "memory");
386 }
387};
388
// Asm path, predicated 8-byte (dwordx2) load. exec is narrowed to the lanes with 1 <= flag
// (unsigned compare) for the load and then restored from saved_exec. Masked-off lanes keep
// their previous `value`.
389template <bool pre_nop>
390struct buffer_load_if<8, pre_nop>
391{
392 template <typename T>
394 int32x4_t res /*buffer resource*/,
395 index_t v_offset,
396 index_t /*s_offset*/,
397 index_t i_offset /*max 0xFFF*/,
398 index_t flag = 0,
400 {
401 static_assert(sizeof(T) == 8);
402 auto saved_exec = __builtin_amdgcn_read_exec();
403 using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
404 if constexpr(pre_nop)
405 asm volatile("s_nop 4\n"
406 "v_cmpx_le_u32 exec, 1, %4\n"
407 "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
408 "s_mov_b64 exec %5"
409 : "+v"(reinterpret_cast<mbuf_t&>(value))
410 : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
411 : "memory");
412 else
413 asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
414 "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
415 "s_mov_b64 exec %5"
416 : "+v"(reinterpret_cast<mbuf_t&>(value))
417 : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
418 : "memory");
419 }
420};
421
// Asm path, predicated 4-byte (dword) load. exec is narrowed to the lanes with 1 <= flag
// (unsigned compare) for the load and then restored from saved_exec. Masked-off lanes keep
// their previous `value`.
422template <bool pre_nop>
423struct buffer_load_if<4, pre_nop>
424{
425 template <typename T>
427 int32x4_t res /*buffer resource*/,
428 index_t v_offset,
429 index_t /*s_offset*/,
430 index_t i_offset /*max 0xFFF*/,
431 index_t flag = 0,
433 {
434 static_assert(sizeof(T) == 4);
435 auto saved_exec = __builtin_amdgcn_read_exec();
436 using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
437 if constexpr(pre_nop)
438 asm volatile("s_nop 4\n"
439 "v_cmpx_le_u32 exec, 1, %4\n"
440 "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
441 "s_mov_b64 exec %5"
442 : "+v"(reinterpret_cast<mbuf_t&>(value))
443 : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
444 : "memory");
445 else
446 asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
447 "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
448 "s_mov_b64 exec %5"
449 : "+v"(reinterpret_cast<mbuf_t&>(value))
450 : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
451 : "memory");
452 }
453};
454
// Asm path, predicated 2-byte load (buffer_load_ushort) into the dword register backing a
// 4-byte `value`. exec is narrowed to the lanes with 1 <= flag (unsigned compare) and then
// restored from saved_exec.
455template <bool pre_nop>
456struct buffer_load_if<2, pre_nop>
457{
458 template <typename T>
460 int32x4_t res /*buffer resource*/,
461 index_t v_offset,
462 index_t /*s_offset*/,
463 index_t i_offset /*max 0xFFF*/,
464 index_t flag = 0,
466 {
467 static_assert(sizeof(T) == 4);
468 auto saved_exec = __builtin_amdgcn_read_exec();
469 using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
470 if constexpr(pre_nop)
471 asm volatile("s_nop 4\n"
472 "v_cmpx_le_u32 exec, 1, %4\n"
473 "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
474 "s_mov_b64 exec %5"
475 : "+v"(reinterpret_cast<mbuf_t&>(value))
476 : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
477 : "memory");
478 else
479 asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
480 "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
481 "s_mov_b64 exec %5"
482 : "+v"(reinterpret_cast<mbuf_t&>(value))
483 : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
484 : "memory");
485 }
486};
487
// Asm path, predicated 1-byte load (buffer_load_ubyte) into the dword register backing a
// 4-byte `value`. exec is narrowed to the lanes with 1 <= flag (unsigned compare) and then
// restored from saved_exec.
488template <bool pre_nop>
489struct buffer_load_if<1, pre_nop>
490{
491 template <typename T>
493 int32x4_t res /*buffer resource*/,
494 index_t v_offset,
495 index_t /*s_offset*/,
496 index_t i_offset /*max 0xFFF*/,
497 index_t flag = 0,
499 {
500 static_assert(sizeof(T) == 4);
501 auto saved_exec = __builtin_amdgcn_read_exec();
502 using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
503 if constexpr(pre_nop)
504 asm volatile("s_nop 4\n"
505 "v_cmpx_le_u32 exec, 1, %4\n"
506 "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
507 "s_mov_b64 exec %5"
508 : "+v"(reinterpret_cast<mbuf_t&>(value))
509 : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
510 : "memory");
511 else
512 asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
513 "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
514 "s_mov_b64 exec %5"
515 : "+v"(reinterpret_cast<mbuf_t&>(value))
516 : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
517 : "memory");
518 }
519};
520#endif
521
522#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
523
// Unconditional 16-byte (dwordx4) buffer store of `value` (sizeof(T) must be 16), bit_cast to
// uint32x4_t. The address is v_offset plus i_offset: i_offset is passed as soffset on the
// builtin path and as the immediate offset on the asm path. s_offset and flag are ignored.
524template <>
525struct buffer_store<16>
526{
527 template <typename T>
529 int32x4_t res /*buffer resource*/,
530 index_t v_offset,
531 index_t /*s_offset*/,
532 index_t i_offset /*max 0xFFF*/,
533 index_t /*flag*/ = 1)
534 {
535 static_assert(sizeof(T) == 16);
536 using mbuf_t = uint32x4_t;
537#if HAS_RAW_BUFFER_BUILTINS
538 index_t s_offset = i_offset;
539 __builtin_amdgcn_raw_buffer_store_b128(
540 bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
541#else
542 asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3"
543 :
544 : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
545 : "memory");
546#endif
547 }
548};
549
// Unconditional 8-byte (dwordx2) buffer store of `value` (sizeof(T) must be 8), bit_cast to
// uint32x2_t. The address is v_offset plus i_offset; s_offset and flag are ignored.
550template <>
552{
553 template <typename T>
555 int32x4_t res /*buffer resource*/,
556 index_t v_offset,
557 index_t /*s_offset*/,
558 index_t i_offset /*max 0xFFF*/,
559 index_t /*flag*/ = 1)
560 {
561 static_assert(sizeof(T) == 8);
562 using mbuf_t = uint32x2_t;
563#if HAS_RAW_BUFFER_BUILTINS
564 index_t s_offset = i_offset;
565 __builtin_amdgcn_raw_buffer_store_b64(
566 bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
567#else
568 asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3"
569 :
570 : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
571 : "memory");
572#endif
573 }
574};
575
// Unconditional 4-byte (dword) buffer store of `value` (sizeof(T) must be 4), bit_cast to
// uint32_t. The address is v_offset plus i_offset; s_offset and flag are ignored.
576template <>
578{
579 template <typename T>
581 int32x4_t res /*buffer resource*/,
582 index_t v_offset,
583 index_t /*s_offset*/,
584 index_t i_offset /*max 0xFFF*/,
585 index_t /*flag*/ = 1)
586 {
587 static_assert(sizeof(T) == 4);
588 using mbuf_t = uint32_t;
589#if HAS_RAW_BUFFER_BUILTINS
590 index_t s_offset = i_offset;
591 __builtin_amdgcn_raw_buffer_store_b32(
592 bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
593#else
594 asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3"
595 :
596 : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
597 : "memory");
598#endif
599 }
600};
601
// Unconditional 2-byte buffer store of `value` (sizeof(T) must be 2), bit_cast to uint16_t.
// The address is v_offset plus i_offset; s_offset and flag are ignored.
602template <>
604{
605 template <typename T>
607 int32x4_t res /*buffer resource*/,
608 index_t v_offset,
609 index_t /*s_offset*/,
610 index_t i_offset /*max 0xFFF*/,
611 index_t /*flag*/ = 1)
612 {
613 static_assert(sizeof(T) == 2);
614 using mbuf_t = uint16_t;
615#if HAS_RAW_BUFFER_BUILTINS
616 index_t s_offset = i_offset;
617 __builtin_amdgcn_raw_buffer_store_b16(
618 bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
619#else
620 asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3"
621 :
622 : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
623 : "memory");
624#endif
625 }
626};
627
// Unconditional 1-byte buffer store of `value` (sizeof(T) must be 1), bit_cast to uint8_t.
// The address is v_offset plus i_offset; s_offset and flag are ignored.
628template <>
630{
631 template <typename T>
633 int32x4_t res /*buffer resource*/,
634 index_t v_offset,
635 index_t /*s_offset*/,
636 index_t i_offset /*max 0xFFF*/,
637 index_t /*flag*/ = 1)
638 {
639 static_assert(sizeof(T) == 1);
640 using mbuf_t = uint8_t;
641#if HAS_RAW_BUFFER_BUILTINS
642 index_t s_offset = i_offset;
643 __builtin_amdgcn_raw_buffer_store_b8(
644 bit_cast<mbuf_t>(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0);
645#else
646 asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3"
647 :
648 : "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
649 : "memory");
650#endif
651 }
652};
653
654#if HAS_RAW_BUFFER_BUILTINS
655template <index_t bytes>
656struct buffer_store_if
657{
658 template <typename T>
659 CK_TILE_DEVICE void operator()(const T& value,
660 int32x4_t res /*buffer resource*/,
661 index_t v_offset,
662 index_t s_offset,
663 index_t i_offset /*max 0xFFF*/,
664 index_t flag = 1)
665 {
666 if LIKELY(1 <= flag)
667 {
668 buffer_store<bytes>{}(value, res, v_offset, s_offset, i_offset);
669 }
670 }
671};
672#else
// Asm path, predicated 16-byte (dwordx4) store of `value` bit_cast to fp32x4_t. The exec mask is
// saved first. v_cmpx_le_u32 then narrows exec to the lanes with 1 <= flag, the store runs for
// those lanes, and s_mov_b64 restores exec.
// NOTE(review): the compare is unsigned (u32), so a negative flag enables the store here. The
// builtin path's `1 <= flag` would not, if index_t is signed -- confirm callers only pass 0/1.
673template <>
675{
676 template <typename T>
678 int32x4_t res /*buffer resource*/,
679 index_t v_offset,
680 index_t /*s_offset*/,
681 index_t i_offset /*max 0xFFF*/,
682 index_t flag = 1)
683 {
684 static_assert(sizeof(T) == 16);
685 auto save_exec = __builtin_amdgcn_read_exec();
686 using mbuf_t = fp32x4_t;
687 asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
688 "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
689 "s_mov_b64 exec %5"
690 :
691 : "v"(bit_cast<mbuf_t>(value)),
692 "v"(v_offset),
693 "s"(res),
694 "n"(i_offset),
695 "v"(flag),
696 "s"(save_exec)
697 : "memory");
698 }
699};
700
// Asm path, predicated 8-byte (dwordx2) store. exec is narrowed to the lanes with 1 <= flag
// (unsigned compare) and then restored from save_exec. T must expose value_type and a static
// size(), which are used to form the ext_vector_t payload.
701template <>
703{
704 template <typename T>
706 int32x4_t res /*buffer resource*/,
707 index_t v_offset,
708 index_t /*s_offset*/,
709 index_t i_offset /*max 0xFFF*/,
710 index_t flag = 1)
711 {
712 static_assert(sizeof(T) == 8);
713 auto save_exec = __builtin_amdgcn_read_exec();
714 // TODO: ugly. rocm-6.0/6.1 seems to need a bit_cast to the same base type to avoid scratch
715 using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
716 asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
717 "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
718 "s_mov_b64 exec %5"
719 :
720 : "v"(bit_cast<mbuf_t>(value)),
721 "v"(v_offset),
722 "s"(res),
723 "n"(i_offset),
724 "v"(flag),
725 "s"(save_exec)
726 : "memory");
727 }
728};
729
// Asm path, predicated 4-byte (dword) store of `value` bit_cast to float. exec is narrowed to
// the lanes with 1 <= flag (unsigned compare) and then restored from save_exec.
730template <>
732{
733 template <typename T>
735 int32x4_t res /*buffer resource*/,
736 index_t v_offset,
737 index_t /*s_offset*/,
738 index_t i_offset /*max 0xFFF*/,
739 index_t flag = 1)
740 {
741 static_assert(sizeof(T) == 4);
742 auto save_exec = __builtin_amdgcn_read_exec();
743 using mbuf_t = float;
744 asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
745 "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
746 "s_mov_b64 exec %5"
747 :
748 : "v"(bit_cast<mbuf_t>(value)),
749 "v"(v_offset),
750 "s"(res),
751 "n"(i_offset),
752 "v"(flag),
753 "s"(save_exec)
754 : "memory");
755 }
756};
757
// Asm path, predicated 2-byte store (buffer_store_short) of `value` bit_cast to short. exec is
// narrowed to the lanes with 1 <= flag (unsigned compare) and then restored from save_exec.
758template <>
760{
761 template <typename T>
763 int32x4_t res /*buffer resource*/,
764 index_t v_offset,
765 index_t /*s_offset*/,
766 index_t i_offset /*max 0xFFF*/,
767 index_t flag = 1)
768 {
769 static_assert(sizeof(T) == 2);
770 auto save_exec = __builtin_amdgcn_read_exec();
771 using mbuf_t = short;
772 asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
773 "buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
774 "s_mov_b64 exec %5"
775 :
776 : "v"(bit_cast<mbuf_t>(value)),
777 "v"(v_offset),
778 "s"(res),
779 "n"(i_offset),
780 "v"(flag),
781 "s"(save_exec)
782 : "memory");
783 }
784};
785
// Asm path, predicated 1-byte store (buffer_store_byte) from a 4-byte `value` bit_cast to float.
// exec is narrowed to the lanes with 1 <= flag (unsigned compare) and then restored.
// NOTE(review): unlike buffer_store<1> (sizeof(T) == 1), this expects a 4-byte T; presumably
// only the low byte of the register is written -- confirm the callers' expectations.
786template <>
788{
789 template <typename T>
791 int32x4_t res /*buffer resource*/,
792 index_t v_offset,
793 index_t /*s_offset*/,
794 index_t i_offset /*max 0xFFF*/,
795 index_t flag = 1)
796 {
797 static_assert(sizeof(T) == 4);
798 auto save_exec = __builtin_amdgcn_read_exec();
799 using mbuf_t = float;
800 asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
801 "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
802 "s_mov_b64 exec %5"
803 :
804 : "v"(bit_cast<mbuf_t>(value)),
805 "v"(v_offset),
806 "s"(res),
807 "n"(i_offset),
808 "v"(flag),
809 "s"(save_exec)
810 : "memory");
811 }
812};
813#endif
814
// Emits `s_waitcnt vmcnt(cnt)`, waiting until at most `cnt` vector-memory operations are
// outstanding. `cnt` is bound with the "n" (immediate) constraint, so it must be a compile-time
// constant. The "memory" clobber keeps memory accesses from being moved across the wait.
816{
817 asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
818}
819
// Emits `s_waitcnt lgkmcnt(cnt)`, waiting on the LDS/GDS/constant/message counter. `cnt` is
// bound with the "n" (immediate) constraint, so it must be a compile-time constant. The
// "memory" clobber keeps memory accesses from being moved across the wait.
821{
822 asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
823}
824
// Primary template for the predicated atomic-add helpers, keyed by element type, element count
// and pre_nop (default false).
825template <typename scalar_type, index_t N, bool pre_nop = false>
827
// Predicated packed-bf16 atomic add (two bf16 values, sizeof(T) == 4).
// Despite the buffer_* naming, this issues global_atomic_pk_add_bf16. The SGPR base address is
// res.xy, the first two dwords of the resource, where buffer_resource keeps its `ptr`. The
// per-lane offset is v_offset and the immediate offset is i_offset. exec is narrowed to the
// lanes with 1 <= flag (unsigned compare) and restored from save_exec afterwards.
// NOTE(review): pre_nop is accepted but no `s_nop` is emitted by this specialization.
828template <bool pre_nop>
829struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
830{
831 template <typename T>
833 int32x4_t res /*buffer resource*/,
834 index_t v_offset,
835 index_t /*s_offset*/,
836 index_t i_offset /*max 0xFFF*/,
837 index_t flag = 1)
838 {
839 static_assert(sizeof(T) == 4);
840 auto save_exec = __builtin_amdgcn_read_exec();
841 using mbuf_t = float;
842 asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
843 "global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
844 "s_mov_b64 exec %5"
845 :
846 : "v"(v_offset),
848 "s"(res.xy),
849 "n"(i_offset),
850 "v"(flag),
851 "s"(save_exec)
852 : "memory");
853 }
854};
855
// Primary template for the unconditional atomic-add helpers (element type, count, pre_nop).
856template <typename scalar_type, index_t N, bool pre_nop = false>
858
// Unconditional packed-bf16 atomic add of `value` (bit_cast to float, sizeof(T) == 4) via
// global_atomic_pk_add_bf16. The address is res.xy (base pointer from the resource's first two
// dwords) + v_offset + i_offset. The trailing flag parameter has no default and is ignored.
// pre_nop is not used by this specialization.
859template <bool pre_nop>
860struct buffer_atomic_add<bf16_t, 2, pre_nop>
861{
862 template <typename T>
864 int32x4_t res /*buffer resource*/,
865 index_t v_offset,
866 index_t /*s_offset*/,
867 index_t i_offset /*max 0xFFF*/,
868 index_t /*flag = 1*/)
869 {
870 static_assert(sizeof(T) == 4);
871 using mbuf_t = float;
872 asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
873 :
874 : "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
875 : "memory");
876 }
877};
878
879namespace impl {
880// below type indicate the data type used for buffer load inline asm
881// clang-format off
882template<index_t N, typename T> struct smem_load_trait;
883
884template<typename T> struct smem_load_trait<16, T> { using payload_t = fp32x4_t; };
885template<typename T> struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; };
886template<typename T> struct smem_load_trait<4 , T> { using payload_t = float; };
887template<typename T> struct smem_load_trait<2 , T> { using payload_t = float; };
888template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
889
890// clang-format on
891} // namespace impl
892
893// NOTE: smem loads/stores don't need pre_nop: dependencies are handled in software, happy :)
// Primary template for the LDS (ds_read) load helpers below, keyed by access width in bytes.
894template <index_t>
896
897template <>
898struct smem_load<16>
899{
900 template <typename T>
901 CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
902 {
903 static_assert(sizeof(T) == 16);
904 using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t;
905 asm volatile("ds_read_b128 %0, %1 offset:%2"
906 : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
907 : "v"(v_offset), "n"(i_offset)
908 : "memory");
909 }
910};
911
912template <>
913struct smem_load<8>
914{
915 template <typename T>
916 CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
917 {
918 static_assert(sizeof(T) == 8);
919 using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;
920 asm volatile("ds_read_b64 %0, %1 offset:%2"
921 : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
922 : "v"(v_offset), "n"(i_offset)
923 : "memory");
924 }
925};
926
927template <>
928struct smem_load<4>
929{
930 template <typename T>
931 CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
932 {
933 static_assert(sizeof(T) == 4);
934 using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;
935 asm volatile("ds_read_b32 %0, %1 offset:%2"
936 : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
937 : "v"(v_offset), "n"(i_offset)
938 : "memory");
939 }
940};
941
942template <>
943struct smem_load<2>
944{
945 template <typename T>
946 CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
947 {
948 static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
949 using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
950 asm volatile("ds_read_u16 %0, %1 offset:%2"
951 : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
952 : "v"(v_offset), "n"(i_offset)
953 : "memory");
954 }
955};
956
957template <>
958struct smem_load<1>
959{
960 template <typename T>
961 CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
962 {
963 static_assert(sizeof(T) == 4);
964 using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
965 asm volatile("ds_read_u8 %0, %1 offset:%2"
966 : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
967 : "v"(v_offset), "n"(i_offset)
968 : "memory");
969 }
970};
971
972// clang-format off
973namespace impl{
974
975// Can't use "+v" since it could introduce extra moves (reads/writes) of the registers.
976// Using input-only "v" operands helps remove such duplicated moves.
977// Besides, the "memory" clobber fakes this as a memory operation, forcing later VALU work to stay after this fence.
978// TODO: may cause scratch usage (because this is treated as memory?)
979// Need to reduce the extra moves inside the compiler.
// Generic form: one empty asm statement per element of b, each taking b.get(i) as a VGPR input.
980template<index_t N>
982{
983 constexpr auto kSize = remove_cvref_t<decltype(b)>::size();
984 static_for<0, kSize, 1>{}([&](auto i){
985 asm volatile(" " : : "v"(b.get(number<i>{})) : "memory");
986 });
987}
988#if 1
989// The specializations below merge all size() dwords into a single asm statement (2, 3, 4, 8, 16 and 32 elements).
990template<>
992{
993 asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})) : "memory");
994}
995
996template<>
998{
999 asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})) : "memory");
1000}
1001
1002template<>
1004{
1005 asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})) : "memory");
1006}
1007
1008template<>
1010{
1011 asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})),
1012 "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})) : "memory");
1013}
1014
1015template<>
1017{
1018 asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})),
1019 "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})),
1020 "v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})),
1021 "v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})) : "memory");
1022}
1023
1024template<>
1026{
1027 asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})),
1028 "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})),
1029 "v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})),
1030 "v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})),
1031 "v"(b.get(number<16>{})), "v"(b.get(number<17>{})), "v"(b.get(number<18>{})), "v"(b.get(number<19>{})),
1032 "v"(b.get(number<20>{})), "v"(b.get(number<21>{})), "v"(b.get(number<22>{})), "v"(b.get(number<23>{})),
1033 "v"(b.get(number<24>{})), "v"(b.get(number<25>{})), "v"(b.get(number<26>{})), "v"(b.get(number<27>{})),
1034 "v"(b.get(number<28>{})), "v"(b.get(number<29>{})), "v"(b.get(number<30>{})), "v"(b.get(number<31>{})) : "memory");
1035}
1036#endif
1038
// Views `buffer` as ceil(sizeof(T) / 4) floats (`dummy`) so it can be fed, dword by dword,
// through the asm dependency helpers above. Presumably `dummy` is then passed to those per-dword
// helpers -- confirm.
1039template<typename T>
1041{
1042 // TODO: we expect T to be a multiple of a dword; sub-dword sizes are always buggy
1043 using da_type = array<float, (sizeof(T) + 3) / 4>;
1044 auto & dummy = reinterpret_cast<da_type&>(buffer);
1046}
1047
// Variadic form: applies insert_dummy_dep to each argument in turn (bx first, then the rest).
1048template<typename Tx, typename... Ty>
1049CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by)
1050{
1051 insert_dummy_dep(bx);
1052 insert_dummy_dep(by...);
1053}
1054}
1055// clang-format on
// vmcnt wait that also takes buffers (T...). It emits `s_waitcnt vmcnt(cnt)` (cnt must be an
// immediate, "n" constraint). Presumably the buffers are then routed through
// impl::insert_dummy_dep so their later uses are ordered after the wait -- confirm.
1056template <typename... T>
1058{
1059 asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
1061}
1062
// Emits `s_waitcnt vmcnt(cnt)` with a "memory" clobber; cnt must be a compile-time immediate
// ("n" constraint).
1064{
1065 asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
1066}
1067
// Emits `s_waitcnt vmcnt(cnt)` with a "memory" clobber; cnt must be a compile-time immediate
// ("n" constraint).
1069{
1070 asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
1071}
1072
// Raw buffer loads of 8-bit integer elements (scalar, 2- and 4-wide), bound through __asm labels
// directly to the llvm.amdgcn.raw.buffer.load.{i8,v2i8,v4i8} intrinsics. Each takes voffset,
// soffset and glc_slc; glc_slc presumably carries the cache-policy bits -- confirm.
1073// buffer load i8
1076 index_t voffset,
1077 index_t soffset,
1078 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
1079
1082 index_t voffset,
1083 index_t soffset,
1084 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
1085
1088 index_t voffset,
1089 index_t soffset,
1090 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
1091
// Raw buffer loads of 16-bit integer elements (scalar, 2- and 4-wide), bound through __asm
// labels to the llvm.amdgcn.raw.buffer.load.{i16,v2i16,v4i16} intrinsics.
1092// buffer load i16
1095 index_t voffset,
1096 index_t soffset,
1097 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
1098
1101 index_t voffset,
1102 index_t soffset,
1103 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
1104
1107 index_t voffset,
1108 index_t soffset,
1109 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");
1110
// Raw buffer loads of 32-bit integer elements (scalar, 2- and 4-wide), bound through __asm
// labels to the llvm.amdgcn.raw.buffer.load.{i32,v2i32,v4i32} intrinsics.
1111// buffer load i32
1114 index_t voffset,
1115 index_t soffset,
1116 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
1117
1120 index_t voffset,
1121 index_t soffset,
1122 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
1123
1126 index_t voffset,
1127 index_t soffset,
1128 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
1129
// Raw buffer loads of half-precision elements (scalar _Float16, 2- and 4-wide), bound through
// __asm labels to the llvm.amdgcn.raw.buffer.load.{f16,v2f16,v4f16} intrinsics.
1130// buffer load fp16
1131CK_TILE_DEVICE_EXTERN _Float16
1133 index_t voffset,
1134 index_t soffset,
1135 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
1136
1139 index_t voffset,
1140 index_t soffset,
1141 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
1142
1145 index_t voffset,
1146 index_t soffset,
1147 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
1148
// Raw buffer loads of single-precision elements (scalar, 2- and 4-wide), bound through __asm
// labels to the llvm.amdgcn.raw.buffer.load.{f32,v2f32,v4f32} intrinsics.
1149// buffer load fp32
1152 index_t voffset,
1153 index_t soffset,
1154 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
1155
1158 index_t voffset,
1159 index_t soffset,
1160 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
1161
1164 index_t voffset,
1165 index_t soffset,
1166 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
1167
1168// buffer store i8
1171 int32x4_t rsrc,
1172 index_t voffset,
1173 index_t soffset,
1174 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
1175
1178 int32x4_t rsrc,
1179 index_t voffset,
1180 index_t soffset,
1181 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
1182
1185 int32x4_t rsrc,
1186 index_t voffset,
1187 index_t soffset,
1188 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
1189
1190// buffer store i16
1193 int32x4_t rsrc,
1194 index_t voffset,
1195 index_t soffset,
1196 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
1197
1200 int32x4_t rsrc,
1201 index_t voffset,
1202 index_t soffset,
1203 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
1204
1207 int32x4_t rsrc,
1208 index_t voffset,
1209 index_t soffset,
1210 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
1211
1212// buffer store i32
1215 int32x4_t rsrc,
1216 index_t voffset,
1217 index_t soffset,
1218 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
1219
1220// buffer store ui16
1223 int32x4_t rsrc,
1224 index_t voffset,
1225 index_t soffset,
1226 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
1227
1230 int32x4_t rsrc,
1231 index_t voffset,
1232 index_t soffset,
1233 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
1234
1237 int32x4_t rsrc,
1238 index_t voffset,
1239 index_t soffset,
1240 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
1241
1244 int32x4_t rsrc,
1245 index_t voffset,
1246 index_t soffset,
1247 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
1248
1251 int32x4_t rsrc,
1252 index_t voffset,
1253 index_t soffset,
1254 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
1255
1256// buffer store fp16
1259 int32x4_t rsrc,
1260 index_t voffset,
1261 index_t soffset,
1262 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
1263
1266 int32x4_t rsrc,
1267 index_t voffset,
1268 index_t soffset,
1269 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
1270
1273 int32x4_t rsrc,
1274 index_t voffset,
1275 index_t soffset,
1276 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
1277
1278// buffer store fp32
1281 int32x4_t rsrc,
1282 index_t voffset,
1283 index_t soffset,
1284 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
1285
1288 int32x4_t rsrc,
1289 index_t voffset,
1290 index_t soffset,
1291 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
1292
1295 int32x4_t rsrc,
1296 index_t voffset,
1297 index_t soffset,
1298 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
1299
1300// buffer atomic-add fp16
1302 fp16x2_t vdata,
1303 int32x4_t rsrc,
1304 index_t voffset,
1305 index_t soffset,
1306 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
1307
1308// buffer atomic-add bf16
1309// TODO: Replace with bf16x2_t, but llvm builins only accept cktile_bf16x2_t now.
1311 bf16x2_t vdata,
1312 int32x4_t rsrc,
1313 index_t voffset,
1314 index_t soffset,
1315 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16");
1316
1317// buffer atomic-add i32
1319 int32_t vdata,
1320 int32x4_t rsrc,
1321 index_t voffset,
1322 index_t soffset,
1323 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
1324
1325// buffer atomic-add fp32
1327 float vdata,
1328 int32x4_t rsrc,
1329 index_t voffset,
1330 index_t soffset,
1331 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
1332
1333// buffer atomic-max fp64
1336 int32x4_t rsrc, // dst_wave_buffer_resource
1337 int voffset, // dst_thread_addr_offset
1338 int soffset, // dst_wave_addr_offset
1339 int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
1340
1341// Direct loads from global to LDS.
1344 as3_uint32_ptr lds_ptr,
1345 index_t size,
1346 index_t voffset,
1347 index_t soffset,
1349 index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
1350
// Buffer load that writes straight into LDS (the "lds" instruction modifier), via inline asm.
//   num_dwords : dwords per lane; 1 on every target, 3 or 4 only when compiling for gfx950
//   pre_nop    : when true, the load is preceded by `s_nop 4`
//   voffset    : per-lane byte offset held in a VGPR ("offen")
//   ioffset    : immediate byte offset; the "n" constraint makes it a compile-time constant
//                (documented max 0xFFF)
// The soffset and flag arguments are accepted but ignored: the asm hard-codes soffset 0.
// `smem` is tied as a dummy "=r" output so the compiler treats it as written by the load.
// NOTE(review): the LDS destination address is not an operand of the asm; presumably the
// caller programs it (e.g. through M0) — confirm. The declaration line and last parameter
// line are missing from this extracted listing.
1351template <unsigned num_dwords, bool pre_nop = false>
1353 int32x4_t rsrc,
1354 index_t voffset,
1355 index_t /*soffset*/,
1356 index_t ioffset /*max 0xFFF*/,
1357 index_t /*flag*/ = 0,
1359{
1360#define CK_TILE_ASYNC_LOAD_WITH_INSTR(instr) \
1361 if constexpr(pre_nop) \
1362 asm volatile("s_nop 4\n" instr " %1, %2, 0 offen offset:%3 lds" \
1363 : "=r"(smem) /*dummy dependency for smem*/ \
1364 : "v"(voffset), "s"(rsrc), "n"(ioffset) \
1365 : "memory"); \
1366 else \
1367 asm volatile(instr " %1, %2, 0 offen offset:%3 lds" \
1368 : "=r"(smem) /*dummy dependency for smem*/ \
1369 : "v"(voffset), "s"(rsrc), "n"(ioffset) \
1370 : "memory");
1371
1372 if constexpr(num_dwords == 1)
1373 {
1374 CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dword");
1375 }
1376#if defined(__gfx950__)
1377 else if constexpr(num_dwords == 3)
1378 {
1379 CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx3");
1380 }
1381 else if constexpr(num_dwords == 4)
1382 {
1383 CK_TILE_ASYNC_LOAD_WITH_INSTR("buffer_load_dwordx4");
1384 }
1385#endif
1386 else
1387 {
// Rejects unsupported widths at compile time (including 3/4 dwords off gfx950).
// static_assert(false) in a discarded if-constexpr branch relies on a compiler that
// implements P2593 (C++23 / recent clang).
1388 static_assert(false, "wrong! not implemented data width");
1389 }
1390#undef CK_TILE_ASYNC_LOAD_WITH_INSTR
1391}
1392
// Emits `s_waitcnt vmcnt(cnt)`: waits until at most `cnt` vector-memory instructions
// (including buffer loads into LDS) are outstanding. The "memory" clobber also makes it a
// compiler memory barrier.
// NOTE(review): the declaration line is missing from this extracted listing; `cnt` must be
// a compile-time constant because of the "n" constraint.
1394{
1395 asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
1396}
1397
1398// Memory-coherency (cache-policy) bits for buffer load/store instructions. The helpers cast
// the chosen value to index_t and pass it as the intrinsic's glc_slc / aux operand.
1399// Bit meanings differ between GFX targets; check the ISA manual for each target,
1400// e.g. the CDNA2 (MI200) ISA,
1401// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf,
1402// pages 67-68.
// NOTE(review): the enum header line and several enumerators are missing from this
// extracted listing.
1404{
1405 coherence_default = 0, // default value
1406 glc = 1, // bit 0
1407 slc = 2, // bit 1
1409 // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
1410 // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
1411 // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
1420};
1421
// Loads N raw bytes for the calling lane from a buffer and returns them as
// thread_buffer<int8_t, N>. `coherence` is forwarded as the cache-policy operand.
//   src_thread_addr_offset : per-lane byte offset (voffset)
//   src_wave_addr_offset   : wave-level byte offset (soffset)
// 1/2/4/8/16 bytes use a single i8/i16/i32/v2i32/v4i32 load. 32 and 64 bytes are split into
// 2 or 4 consecutive 16-byte (int32x4) loads, with the wave offset advanced by 16 bytes
// for each chunk.
// NOTE(review): the declaration lines and the `tmp` declarations of the 32/64-byte paths
// are missing from this extracted listing.
1422template <index_t N,
1426 index_t src_thread_addr_offset,
1427 index_t src_wave_addr_offset)
1428{
1429 static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
1430 "wrong! not implemented");
1431
1432 using rtn_type = thread_buffer<int8_t, N>;
1433
1434 if constexpr(N == 1)
1435 {
1436 return bit_cast<rtn_type>(llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
1437 src_thread_addr_offset,
1438 src_wave_addr_offset,
1439 static_cast<index_t>(coherence)));
1440 }
1441 else if constexpr(N == 2)
1442 {
1443
1444 int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
1445 src_thread_addr_offset,
1446 src_wave_addr_offset,
1447 static_cast<index_t>(coherence));
1448
1449 return bit_cast<rtn_type>(tmp);
1450 }
1451 else if constexpr(N == 4)
1452 {
1453 int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
1454 src_thread_addr_offset,
1455 src_wave_addr_offset,
1456 static_cast<index_t>(coherence));
1457
1458 return bit_cast<rtn_type>(tmp);
1459 }
1460 else if constexpr(N == 8)
1461 {
1462 int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
1463 src_thread_addr_offset,
1464 src_wave_addr_offset,
1465 static_cast<index_t>(coherence));
1466
1467 return bit_cast<rtn_type>(tmp);
1468 }
1469 else if constexpr(N == 16)
1470 {
1471 int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
1472 src_thread_addr_offset,
1473 src_wave_addr_offset,
1474 static_cast<index_t>(coherence));
1475 return bit_cast<rtn_type>(tmp);
1476 }
1477 else if constexpr(N == 32)
1478 {
// Two 16-byte loads; the second one starts 16 bytes further along the wave offset.
1479 int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
1480 src_thread_addr_offset,
1481 src_wave_addr_offset,
1482 static_cast<index_t>(coherence));
1483 int32x4_t tmp1 =
1484 llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
1485 src_thread_addr_offset,
1486 src_wave_addr_offset + 4 * sizeof(int32_t),
1487 static_cast<index_t>(coherence));
1489
1490 tmp.template get_as<int32x4_t>()(number<0>{}) = tmp0;
1491 tmp.template get_as<int32x4_t>()(number<1>{}) = tmp1;
1492
1493 return bit_cast<rtn_type>(tmp);
1494 }
1495 else if constexpr(N == 64)
1496 {
// Four 16-byte loads at wave offsets +0, +16, +32 and +48 bytes.
1497 int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
1498 src_thread_addr_offset,
1499 src_wave_addr_offset,
1500 static_cast<index_t>(coherence));
1501 int32x4_t tmp1 =
1502 llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
1503 src_thread_addr_offset,
1504 src_wave_addr_offset + 4 * sizeof(int32_t),
1505 static_cast<index_t>(coherence));
1506 int32x4_t tmp2 =
1507 llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
1508 src_thread_addr_offset,
1509 src_wave_addr_offset + 8 * sizeof(int32_t),
1510 static_cast<index_t>(coherence));
1511 int32x4_t tmp3 =
1512 llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
1513 src_thread_addr_offset,
1514 src_wave_addr_offset + 12 * sizeof(int32_t),
1515 static_cast<index_t>(coherence));
1516
1518
1519 tmp.template get_as<int32x4_t>()(number<0>{}) = tmp0;
1520 tmp.template get_as<int32x4_t>()(number<1>{}) = tmp1;
1521 tmp.template get_as<int32x4_t>()(number<2>{}) = tmp2;
1522 tmp.template get_as<int32x4_t>()(number<3>{}) = tmp3;
1523
1524 return bit_cast<rtn_type>(tmp);
1525 }
1526}
1527
// Build-time switch; defaults to 0 unless the build already defines it.
#if !defined(BUFFER_LOAD_USE_INLINEASM)
#define BUFFER_LOAD_USE_INLINEASM 0
#endif
1531
// Loads N elements of type T for the calling lane from a buffer and returns them as
// thread_buffer<T, N>. `coherence` is forwarded as the cache-policy operand.
//   src_thread_addr_offset : per-lane byte offset (voffset)
//   src_wave_addr_offset   : wave-level byte offset (soffset)
// float, fp16_t and bf16_t have dedicated paths. Vectors wider than 16 bytes are split into
// several 16-byte (x4 dword) loads, with the wave offset advanced by 16 bytes for each
// chunk. Every other accepted type is loaded as raw bytes and bit-cast to T.
// FIX: the pk_int4_t / pk_fp4_t clauses of the static_assert were mis-parenthesised. The
// pk_fp4_t clause was nested inside the pk_int4_t group, which gives `a && b || c` (clang
// diagnoses this with -Wlogical-op-parentheses). Each type now has its own parenthesised
// clause, and the set of accepted (T, N) pairs is unchanged.
// NOTE(review): the declaration lines and several local `tmp` declarations are missing
// from this extracted listing.
1532template <typename T,
1533 index_t N,
1536 index_t src_thread_addr_offset,
1537 index_t src_wave_addr_offset)
1538{
1539 static_assert(
1540 (std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
1541 (std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
1542 (std::is_same<T, fp16_t>::value &&
1543 (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
1544 (std::is_same<T, bf16_t>::value &&
1545 (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
1546 (std::is_same<T, int32_t>::value &&
1547 (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
1548 (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
1549 (std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
1550 (std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
1551 (std::is_same<T, e8m0_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
1552 (std::is_same<T, pk_int4_t>::value &&
1553 (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
1554 (std::is_same<T, pk_fp4_t>::value &&
1555 (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
1556 "wrong! not implemented");
1557
1558 using rtn_type = thread_buffer<T, N>;
1559
1560 if constexpr(std::is_same<T, float>::value) // fp32
1561 {
1562 if constexpr(N == 1)
1563 {
1564 return bit_cast<rtn_type>(
1565 llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource,
1566 src_thread_addr_offset,
1567 src_wave_addr_offset,
1568 static_cast<index_t>(coherence)));
1569 }
1570 else if constexpr(N == 2)
1571 {
1572 return bit_cast<rtn_type>(
1573 llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
1574 src_thread_addr_offset,
1575 src_wave_addr_offset,
1576 static_cast<index_t>(coherence)));
1577 }
1578 else if constexpr(N == 4)
1579 {
1580 return bit_cast<rtn_type>(
1581 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1582 src_thread_addr_offset,
1583 src_wave_addr_offset,
1584 static_cast<index_t>(coherence)));
1585 }
1586 else if constexpr(N == 8)
1587 {
// Two fp32x4 loads at wave offsets +0 and +16 bytes.
1589
1590 tmp.template get_as<fp32x4_t>()(number<0>{}) =
1591 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1592 src_thread_addr_offset,
1593 src_wave_addr_offset,
1594 static_cast<index_t>(coherence));
1595
1596 tmp.template get_as<fp32x4_t>()(number<1>{}) =
1597 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1598 src_thread_addr_offset,
1599 src_wave_addr_offset + 4 * sizeof(float),
1600 static_cast<index_t>(coherence));
1601
1602 return tmp;
1603 }
1604 else if constexpr(N == 16)
1605 {
// Four fp32x4 loads at wave offsets +0, +16, +32 and +48 bytes.
1607
1608 tmp.template get_as<fp32x4_t>()(number<0>{}) =
1609 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1610 src_thread_addr_offset,
1611 src_wave_addr_offset,
1612 static_cast<index_t>(coherence));
1613
1614 tmp.template get_as<fp32x4_t>()(number<1>{}) =
1615 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1616 src_thread_addr_offset,
1617 src_wave_addr_offset + 4 * sizeof(float),
1618 static_cast<index_t>(coherence));
1619
1620 tmp.template get_as<fp32x4_t>()(number<2>{}) =
1621 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1622 src_thread_addr_offset,
1623 src_wave_addr_offset + 8 * sizeof(float),
1624 static_cast<index_t>(coherence));
1625
1626 tmp.template get_as<fp32x4_t>()(number<3>{}) =
1627 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1628 src_thread_addr_offset,
1629 src_wave_addr_offset + 12 * sizeof(float),
1630 static_cast<index_t>(coherence));
1631
1632 return tmp;
1633 }
1634 }
1635 else if constexpr(std::is_same<T, fp16_t>::value) // fp16
1636 {
1637 if constexpr(N == 1)
1638 {
1639 return bit_cast<rtn_type>(
1640 llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource,
1641 src_thread_addr_offset,
1642 src_wave_addr_offset,
1643 static_cast<index_t>(coherence)));
1644 }
1645 else if constexpr(N == 2)
1646 {
1647 return bit_cast<rtn_type>(
1648 llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource,
1649 src_thread_addr_offset,
1650 src_wave_addr_offset,
1651 static_cast<index_t>(coherence)));
1652 }
1653 else if constexpr(N == 4)
1654 {
1655 return bit_cast<rtn_type>(
1656 llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
1657 src_thread_addr_offset,
1658 src_wave_addr_offset,
1659 static_cast<index_t>(coherence)));
1660 }
1661 else if constexpr(N == 8)
1662 {
1663 // use fp32 load to mimic fp16 load
1664 fp32x4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1665 src_thread_addr_offset,
1666 src_wave_addr_offset,
1667 static_cast<index_t>(coherence));
1668
1669 return bit_cast<rtn_type>(tmp);
1670 }
1671 else if constexpr(N == 16)
1672 {
// 32 bytes: two fp32x4 loads, bit-cast back to fp16.
1674
1675 tmp.template get_as<fp32x4_t>()(number<0>{}) =
1676 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1677 src_thread_addr_offset,
1678 src_wave_addr_offset,
1679 static_cast<index_t>(coherence));
1680
1681 tmp.template get_as<fp32x4_t>()(number<1>{}) =
1682 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1683 src_thread_addr_offset,
1684 src_wave_addr_offset + 4 * sizeof(float),
1685 static_cast<index_t>(coherence));
1686
1687 return bit_cast<rtn_type>(tmp);
1688 }
1689 else if constexpr(N == 32)
1690 {
// 64 bytes: four fp32x4 loads, bit-cast back to fp16.
1692
1693 tmp.template get_as<fp32x4_t>()(number<0>{}) =
1694 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1695 src_thread_addr_offset,
1696 src_wave_addr_offset,
1697 static_cast<index_t>(coherence));
1698
1699 tmp.template get_as<fp32x4_t>()(number<1>{}) =
1700 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1701 src_thread_addr_offset,
1702 src_wave_addr_offset + 4 * sizeof(float),
1703 static_cast<index_t>(coherence));
1704
1705 tmp.template get_as<fp32x4_t>()(number<2>{}) =
1706 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1707 src_thread_addr_offset,
1708 src_wave_addr_offset + 8 * sizeof(float),
1709 static_cast<index_t>(coherence));
1710
1711 tmp.template get_as<fp32x4_t>()(number<3>{}) =
1712 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1713 src_thread_addr_offset,
1714 src_wave_addr_offset + 12 * sizeof(float),
1715 static_cast<index_t>(coherence));
1716
1717 return bit_cast<rtn_type>(tmp);
1718 }
1719 }
1720 else if constexpr(std::is_same<T, bf16_t>::value) // bf16
1721 {
// bf16 is moved as raw 16-bit integers / dwords and bit-cast back to bf16_t.
1722 if constexpr(N == 1)
1723 {
1724 return bit_cast<rtn_type>(
1725 llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
1726 src_thread_addr_offset,
1727 src_wave_addr_offset,
1728 static_cast<index_t>(coherence)));
1729 }
1730 else if constexpr(N == 2)
1731 {
1732 return bit_cast<rtn_type>(
1733 llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource,
1734 src_thread_addr_offset,
1735 src_wave_addr_offset,
1736 static_cast<index_t>(coherence)));
1737 }
1738 else if constexpr(N == 4)
1739 {
1740 return bit_cast<rtn_type>(
1741 llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource,
1742 src_thread_addr_offset,
1743 src_wave_addr_offset,
1744 static_cast<index_t>(coherence)));
1745 }
1746 else if constexpr(N == 8)
1747 {
1748 int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
1749 src_thread_addr_offset,
1750 src_wave_addr_offset,
1751 static_cast<index_t>(coherence));
1752
1753 return bit_cast<rtn_type>(tmp);
1754 }
1755 else if constexpr(N == 16)
1756 {
1758
1759 tmp.template get_as<fp32x4_t>()(number<0>{}) =
1760 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1761 src_thread_addr_offset,
1762 src_wave_addr_offset,
1763 static_cast<index_t>(coherence));
1764
1765 tmp.template get_as<fp32x4_t>()(number<1>{}) =
1766 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1767 src_thread_addr_offset,
1768 src_wave_addr_offset + 4 * sizeof(float),
1769 static_cast<index_t>(coherence));
1770
1771 return bit_cast<rtn_type>(tmp);
1772 }
1773 else if constexpr(N == 32)
1774 {
1776
1777 tmp.template get_as<fp32x4_t>()(number<0>{}) =
1778 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1779 src_thread_addr_offset,
1780 src_wave_addr_offset,
1781 static_cast<index_t>(coherence));
1782
1783 tmp.template get_as<fp32x4_t>()(number<1>{}) =
1784 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1785 src_thread_addr_offset,
1786 src_wave_addr_offset + 4 * sizeof(float),
1787 static_cast<index_t>(coherence));
1788
1789 tmp.template get_as<fp32x4_t>()(number<2>{}) =
1790 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1791 src_thread_addr_offset,
1792 src_wave_addr_offset + 8 * sizeof(float),
1793 static_cast<index_t>(coherence));
1794
1795 tmp.template get_as<fp32x4_t>()(number<3>{}) =
1796 llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
1797 src_thread_addr_offset,
1798 src_wave_addr_offset + 12 * sizeof(float),
1799 static_cast<index_t>(coherence));
1800
1801 return bit_cast<rtn_type>(tmp);
1802 }
1803 }
1804 else // other datatype
1805 {
// Raw byte load of sizeof(T) * N bytes, bit-cast to T. The call line is missing from this
// listing; presumably it is amd_buffer_load_impl_with_bytes — confirm.
1807 src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);
1808
1809 return bit_cast<rtn_type>(raw_data);
1810 }
1811}
1812
// Loads sizeof(T) * N bytes (1, 2, 4, 8 or 16) into `dst` through the inline-asm buffer-load
// functors instead of the LLVM intrinsics.
//   oob_conditional_check : true  -> buffer_load_if<bytes, pre_nop>
//                           false -> buffer_load<bytes, pre_nop>
//   pre_nop               : forwarded to the functor as a template argument
//   src_linear_addr_offset and flag are forwarded unchanged in both cases.
// NOTE(review): the declaration line, one template-parameter line, the last parameter line
// and the last argument of each call are missing from this extracted listing.
1813template <typename T,
1814 index_t N,
1816 bool oob_conditional_check = true,
1817 bool pre_nop = false>
1819 int32x4_t src_wave_buffer_resource,
1820 index_t src_thread_addr_offset,
1821 index_t src_wave_addr_offset,
1822 index_t src_linear_addr_offset,
1823 index_t flag = 0,
1825{
1826 constexpr index_t bytes = sizeof(T) * N;
1827 static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
1828 "wrong! not supported by buffer_load instruction");
1829
1830 using type = thread_buffer<T, N>;
1831 if constexpr(oob_conditional_check)
1832 {
1833 buffer_load_if<sizeof(type), pre_nop>{}(dst,
1834 src_wave_buffer_resource,
1835 src_thread_addr_offset,
1836 src_wave_addr_offset,
1837 src_linear_addr_offset,
1838 flag,
1840 }
1841 else
1842 {
1843 buffer_load<sizeof(type), pre_nop>{}(dst,
1844 src_wave_buffer_resource,
1845 src_thread_addr_offset,
1846 src_wave_addr_offset,
1847 src_linear_addr_offset,
1848 flag,
1850 }
1851}
1852
// Asynchronous buffer load of sizeof(T) * N bytes for the calling lane. The payload must be
// 4, 12 or 16 bytes (1, 3 or 4 dwords), and this is enforced at compile time. The remaining
// arguments are forwarded together with a literal 0 flag.
// NOTE(review): the forwarded-to call line is missing from this extracted listing; given the
// dword checks, it presumably targets async_buffer_load_dword_v<num_words, pre_nop> — confirm.
1853template <typename T,
1854 index_t N,
1856 bool pre_nop = false>
1858 int32x4_t src_wave_buffer_resource,
1859 index_t src_thread_addr_offset,
1860 index_t src_wave_addr_offset,
1861 index_t src_immediate_addr_offset = 0,
1863{
1864 constexpr index_t num_bytes = sizeof(T) * N;
1865 constexpr index_t num_words = num_bytes / 4;
1866 static_assert(num_bytes % 4 == 0 && (num_words == 1 || num_words == 3 || num_words == 4),
1867 "wrong! only support in dword, dwordx3, dwordx4");
1868
1870 src_wave_buffer_resource,
1871 src_thread_addr_offset,
1872 src_wave_addr_offset,
1873 src_immediate_addr_offset,
1874 0,
1876}
1877
// Global -> LDS load of sizeof(T) * N bytes per lane through llvm.amdgcn.raw.buffer.load.lds.
//   smem                      : LDS destination; C-style cast to an address_space(3) pointer
//   src_immediate_addr_offset : must be 0; this is only checked by assert (debug builds)
//   flag                      : with oob_conditional_check, a zero flag replaces the lane's
//                               offset with src_wave_buffer_resource[2]
// Size limits: gfx950 accepts 4, 12 or 16 bytes; every other target accepts only 4 bytes.
// NOTE(review): resource word 2 is presumably the buffer's record count, so a lane whose
// offset is replaced lands out of range and its load is suppressed by hardware bounds
// checking — confirm against the resource-descriptor layout.
// NOTE(review): on gfx950 the caller's src_wave_addr_offset is overwritten with 0 and never
// used; the reason is not visible here — confirm callers fold it into the thread offset.
1878template <typename T,
1879 index_t N,
1881 bool oob_conditional_check = true>
1883 int32x4_t src_wave_buffer_resource,
1884 index_t src_thread_addr_offset,
1885 index_t src_wave_addr_offset,
1886 index_t src_immediate_addr_offset = 0,
1887 index_t flag = 0,
1889{
1890 constexpr index_t bytes = sizeof(T) * N;
1891
1892 // Used to catch the cases when src_immediate_addr_offset is NOT 0.
1893 // Remove this assert once other sizes are implemented.
1894 assert(src_immediate_addr_offset == 0 &&
1895 "wrong! not implemented src_immediate_addr_offset size, only 0 supported");
1896 ignore = src_immediate_addr_offset;
1897
1898#if defined(__gfx950__)
1899 static_assert(bytes == 4 || bytes == 12 || bytes == 16,
1900 "wrong! only support in dword, dwordx3, dwordx4");
1901 src_wave_addr_offset = 0;
1902#else
1903 static_assert(bytes == 4, "wrong! not implemented vector size");
1904#endif
1905
1906 // Set up v_offset:
1907 index_t v_offset = src_thread_addr_offset;
1908 if constexpr(oob_conditional_check)
1909 v_offset = flag ? v_offset : src_wave_buffer_resource[2];
1910
1911#pragma clang diagnostic push
1912#pragma clang diagnostic ignored "-Wold-style-cast"
1913 // Use C-style cast to change address space without dropping llvm noalias attribute
1914 llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
1915 (as3_uint32_ptr)(smem),
1916 bytes,
1917 v_offset,
1918 src_wave_addr_offset,
1919 /*src_immediate_addr_offset*/ 0,
1920 static_cast<index_t>(coherence));
1921#pragma clang diagnostic pop
1922}
1923
// Stores N raw bytes from `src_thread_data` to a buffer for the calling lane. `coherence`
// is forwarded as the cache-policy operand.
//   dst_thread_addr_offset : per-lane byte offset (voffset)
//   dst_wave_addr_offset   : wave-level byte offset (soffset)
// 1/2/4/8/16 bytes go out as a single store. 32 and 64 bytes are split into 2 or 4 int32x4
// chunks, with the wave offset advanced by 16 bytes for each chunk.
// NOTE(review): the declaration lines and the store-call lines are missing from this
// extracted listing.
1924template <index_t N,
1927 int32x4_t dst_wave_buffer_resource,
1928 index_t dst_thread_addr_offset,
1929 index_t dst_wave_addr_offset)
1930{
1931 static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
1932 "wrong! not implemented");
1933
1934 if constexpr(N == 1)
1935 {
1937 dst_wave_buffer_resource,
1938 dst_thread_addr_offset,
1939 dst_wave_addr_offset,
1940 static_cast<index_t>(coherence));
1941 }
1942 else if constexpr(N == 2)
1943 {
1944
1946 dst_wave_buffer_resource,
1947 dst_thread_addr_offset,
1948 dst_wave_addr_offset,
1949 static_cast<index_t>(coherence));
1950 }
1951 else if constexpr(N == 4)
1952 {
1954 dst_wave_buffer_resource,
1955 dst_thread_addr_offset,
1956 dst_wave_addr_offset,
1957 static_cast<index_t>(coherence));
1958 }
1959 else if constexpr(N == 8)
1960 {
1962 dst_wave_buffer_resource,
1963 dst_thread_addr_offset,
1964 dst_wave_addr_offset,
1965 static_cast<index_t>(coherence));
1966 }
1967 else if constexpr(N == 16)
1968 {
1970 dst_wave_buffer_resource,
1971 dst_thread_addr_offset,
1972 dst_wave_addr_offset,
1973 static_cast<index_t>(coherence));
1974 }
1975 else if constexpr(N == 32)
1976 {
// Two int32x4 chunks at wave offsets +0 and +16 bytes.
1978 src_thread_data.template get_as<int32x4_t>()[number<0>{}],
1979 dst_wave_buffer_resource,
1980 dst_thread_addr_offset,
1981 dst_wave_addr_offset,
1982 static_cast<index_t>(coherence));
1983
1985 src_thread_data.template get_as<int32x4_t>()[number<1>{}],
1986 dst_wave_buffer_resource,
1987 dst_thread_addr_offset,
1988 dst_wave_addr_offset + sizeof(int32_t) * 4,
1989 static_cast<index_t>(coherence));
1990 }
1991 else if constexpr(N == 64)
1992 {
// Four int32x4 chunks at wave offsets +0, +16, +32 and +48 bytes.
1994 src_thread_data.template get_as<int32x4_t>()[number<0>{}],
1995 dst_wave_buffer_resource,
1996 dst_thread_addr_offset,
1997 dst_wave_addr_offset,
1998 static_cast<index_t>(coherence));
1999
2001 src_thread_data.template get_as<int32x4_t>()[number<1>{}],
2002 dst_wave_buffer_resource,
2003 dst_thread_addr_offset,
2004 dst_wave_addr_offset + sizeof(int32_t) * 4,
2005 static_cast<index_t>(coherence));
2006
2008 src_thread_data.template get_as<int32x4_t>()[number<2>{}],
2009 dst_wave_buffer_resource,
2010 dst_thread_addr_offset,
2011 dst_wave_addr_offset + sizeof(int32_t) * 8,
2012 static_cast<index_t>(coherence));
2013
2015 src_thread_data.template get_as<int32x4_t>()[number<3>{}],
2016 dst_wave_buffer_resource,
2017 dst_thread_addr_offset,
2018 dst_wave_addr_offset + sizeof(int32_t) * 12,
2019 static_cast<index_t>(coherence));
2020 }
2021}
2022
// Stores the calling lane's thread_buffer<T, N> (`src_thread_data`) to a buffer through raw
// buffer-store intrinsics. `coherence` is forwarded as the cache-policy operand.
//   dst_thread_addr_offset : per-lane byte offset (voffset)
//   dst_wave_addr_offset   : wave-level byte offset (soffset)
// float / fp16_t / bf16_t / uint16_t with N <= 8 use type-specific store paths. Every other
// accepted (T, N) pair is reinterpreted as thread_buffer<int8_t, sizeof(T) * N> and stored by
// the generic branch at the end. That includes N == 16 for those four types (32 or 64 bytes,
// sizes the generic branch already handles for double and int32_t).
// FIX: the static_assert accepts N == 16 for float, fp16_t, bf16_t and uint16_t, but their
// branches only covered N <= 8. Such a store compiled and silently wrote nothing. Those
// cases now fall through to the generic byte-store branch.
// NOTE(review): the declaration lines and the store-call lines are missing from this
// extracted listing.
2023template <typename T,
2024 index_t N,
2027 int32x4_t dst_wave_buffer_resource,
2028 index_t dst_thread_addr_offset,
2029 index_t dst_wave_addr_offset)
2030{
2031 static_assert(
2032 (std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
2033 (std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
2034 (std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
2035 (std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
2036 (std::is_same<T, int32_t>::value &&
2037 (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
2038 (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
2039 (std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
2040 (std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
2041 (std::is_same<T, uint16_t>::value &&
2042 (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
2043 (std::is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
2044 "wrong! not implemented");
2045
2046 if constexpr(std::is_same<T, float>::value && N <= 8) // fp32; N == 16 -> generic branch
2047 {
2048 if constexpr(N == 1)
2049 {
2051 dst_wave_buffer_resource,
2052 dst_thread_addr_offset,
2053 dst_wave_addr_offset,
2054 static_cast<index_t>(coherence));
2055 }
2056 else if constexpr(N == 2)
2057 {
2059 dst_wave_buffer_resource,
2060 dst_thread_addr_offset,
2061 dst_wave_addr_offset,
2062 static_cast<index_t>(coherence));
2063 }
2064 else if constexpr(N == 4)
2065 {
2067 dst_wave_buffer_resource,
2068 dst_thread_addr_offset,
2069 dst_wave_addr_offset,
2070 static_cast<index_t>(coherence));
2071 }
2072 else if constexpr(N == 8)
2073 {
// Two fp32x4 halves at wave offsets +0 and +16 bytes.
2075 src_thread_data.template get_as<fp32x4_t>()[number<0>{}],
2076 dst_wave_buffer_resource,
2077 dst_thread_addr_offset,
2078 dst_wave_addr_offset,
2079 static_cast<index_t>(coherence));
2081 src_thread_data.template get_as<fp32x4_t>()[number<1>{}],
2082 dst_wave_buffer_resource,
2083 dst_thread_addr_offset,
2084 dst_wave_addr_offset + 4 * sizeof(float),
2085 static_cast<index_t>(coherence));
2086 }
2087 }
2088 else if constexpr(std::is_same<T, fp16_t>::value && N <= 8) // fp16; N == 16 -> generic branch
2089 {
2090 if constexpr(N == 1)
2091 {
2093 dst_wave_buffer_resource,
2094 dst_thread_addr_offset,
2095 dst_wave_addr_offset,
2096 static_cast<index_t>(coherence));
2097 }
2098 else if constexpr(N == 2)
2099 {
2101 dst_wave_buffer_resource,
2102 dst_thread_addr_offset,
2103 dst_wave_addr_offset,
2104 static_cast<index_t>(coherence));
2105 }
2106 else if constexpr(N == 4)
2107 {
2109 dst_wave_buffer_resource,
2110 dst_thread_addr_offset,
2111 dst_wave_addr_offset,
2112 static_cast<index_t>(coherence));
2113 }
2114 else if constexpr(N == 8)
2115 {
2116#if 0
2117 thread_buffer<fp16_t, 8> tmp{src_thread_data};
2118
2119 llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<0>{}],
2120 dst_wave_buffer_resource,
2121 dst_thread_addr_offset,
2122 dst_wave_addr_offset,
2123 static_cast<index_t>(coherence));
2124
2125 llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as<fp16x4_t>()[number<1>{}],
2126 dst_wave_buffer_resource,
2127 dst_thread_addr_offset,
2128 dst_wave_addr_offset + 4 * sizeof(fp16_t),
2129 static_cast<index_t>(coherence));
2130#else
2132 dst_wave_buffer_resource,
2133 dst_thread_addr_offset,
2134 dst_wave_addr_offset,
2135 static_cast<index_t>(coherence));
2136#endif
2137 }
2138 }
2139 else if constexpr(std::is_same<T, bf16_t>::value && N <= 8) // bf16; N == 16 -> generic branch
2140 {
2141 if constexpr(N == 1)
2142 {
2144 dst_wave_buffer_resource,
2145 dst_thread_addr_offset,
2146 dst_wave_addr_offset,
2147 static_cast<index_t>(coherence));
2148 }
2149 else if constexpr(N == 2)
2150 {
2152 dst_wave_buffer_resource,
2153 dst_thread_addr_offset,
2154 dst_wave_addr_offset,
2155 static_cast<index_t>(coherence));
2156 }
2157 else if constexpr(N == 4)
2158 {
2160 dst_wave_buffer_resource,
2161 dst_thread_addr_offset,
2162 dst_wave_addr_offset,
2163 static_cast<index_t>(coherence));
2164 }
2165 else if constexpr(N == 8)
2166 {
// Two int16x4 halves at wave offsets +0 and +8 bytes.
2168 src_thread_data.template get_as<int16x4_t>()[number<0>{}],
2169 dst_wave_buffer_resource,
2170 dst_thread_addr_offset,
2171 dst_wave_addr_offset,
2172 static_cast<index_t>(coherence));
2173
2175 src_thread_data.template get_as<int16x4_t>()[number<1>{}],
2176 dst_wave_buffer_resource,
2177 dst_thread_addr_offset,
2178 dst_wave_addr_offset + 4 * sizeof(bf16_t),
2179 static_cast<index_t>(coherence));
2180 }
2181 }
2182 else if constexpr(std::is_same<T, uint16_t>::value && N <= 8) // N == 16 -> generic branch
2183 {
2184 if constexpr(N == 1)
2185 {
2187 dst_wave_buffer_resource,
2188 dst_thread_addr_offset,
2189 dst_wave_addr_offset,
2190 static_cast<index_t>(coherence));
2191 }
2192 else if constexpr(N == 2)
2193 {
2195 dst_wave_buffer_resource,
2196 dst_thread_addr_offset,
2197 dst_wave_addr_offset,
2198 static_cast<index_t>(coherence));
2199 }
2200 else if constexpr(N == 4)
2201 {
2203 dst_wave_buffer_resource,
2204 dst_thread_addr_offset,
2205 dst_wave_addr_offset,
2206 static_cast<index_t>(coherence));
2207 }
2208 else if constexpr(N == 8)
2209 {
// Two uint16x4 halves at wave offsets +0 and +8 bytes.
2211 src_thread_data.template get_as<uint16x4_t>()[number<0>{}],
2212 dst_wave_buffer_resource,
2213 dst_thread_addr_offset,
2214 dst_wave_addr_offset,
2215 static_cast<index_t>(coherence));
2216
2218 src_thread_data.template get_as<uint16x4_t>()[number<1>{}],
2219 dst_wave_buffer_resource,
2220 dst_thread_addr_offset,
2221 dst_wave_addr_offset + 4 * sizeof(uint16_t),
2222 static_cast<index_t>(coherence));
2223 }
2224 }
2225 else
2226 {
// Generic path: store the payload as raw bytes. Also used for N == 16 of
// float / fp16_t / bf16_t / uint16_t (64 or 32 bytes).
2227 using r_t = thread_buffer<int8_t, sizeof(T) * N>;
2228
2230 dst_wave_buffer_resource,
2231 dst_thread_addr_offset,
2232 dst_wave_addr_offset);
2233 }
2234}
2235
// Stores sizeof(T) * N bytes (1, 2, 4, 8 or 16) through the inline-asm buffer-store functors
// instead of the LLVM intrinsics.
//   oob_conditional_check : true  -> buffer_store_if<bytes>, which also receives
//                                    is_valid_element
//                           false -> buffer_store<bytes>; is_valid_element is not passed,
//                                    so the store is issued regardless of its value
// NOTE(review): the declaration line and one template-parameter line are missing from this
// extracted listing.
2236template <typename T,
2237 index_t N,
2239 bool oob_conditional_check = true>
2241 int32x4_t dst_wave_buffer_resource,
2242 index_t dst_thread_addr_offset,
2243 index_t dst_wave_addr_offset,
2244 index_t dst_linear_addr_offset,
2245 index_t is_valid_element = 1)
2246{
2247 constexpr index_t bytes = sizeof(T) * N;
2248 static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
2249 "wrong! not supported by buffer_store instruction");
2250
2251 using type = thread_buffer<T, N>;
2252 if constexpr(oob_conditional_check)
2253 {
2254 buffer_store_if<sizeof(type)>{}(dst_thread_data,
2255 dst_wave_buffer_resource,
2256 dst_thread_addr_offset,
2257 dst_wave_addr_offset,
2258 dst_linear_addr_offset,
2259 is_valid_element);
2260 }
2261 else
2262 {
2263 buffer_store<sizeof(type)>{}(dst_thread_data,
2264 dst_wave_buffer_resource,
2265 dst_thread_addr_offset,
2266 dst_wave_addr_offset,
2267 dst_linear_addr_offset);
2268 }
2269}
2270
// Issues raw buffer atomic-add operations for the N elements of src_thread_data into the buffer
// described by dst_wave_buffer_resource, at per-thread byte offset dst_thread_addr_offset plus
// per-wave byte offset dst_wave_addr_offset.
// Per the doxygen tooltip for this symbol, the full signature is:
//   amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_thread_data,
//                              int32x4_t dst_wave_buffer_resource,
//                              index_t dst_thread_addr_offset,
//                              index_t dst_wave_addr_offset)
// NOTE(review): the signature line (2272) and every intrinsic-call head line are missing from
// this rendered listing. The llvm_amdgcn_raw_buffer_atomic_add_{fp32,fp16x2,bf16x2} declarations
// in the tooltip list are presumably the callees -- confirm against the real header.
// Each call passes (data, resource, dst_thread_addr_offset, <wave offset>, 0); the trailing 0 is
// presumably the glc_slc cache-policy argument of those declarations -- confirm.
// Wide vectors are split into several independent atomics, so the vector as a whole is NOT
// updated atomically.
2271template <typename T, index_t N>
2273 int32x4_t dst_wave_buffer_resource,
2274 index_t dst_thread_addr_offset,
2275 index_t dst_wave_addr_offset)
2276{
// Supported (T, N): float {1,2,4}, fp16_t {2,4,8}, bf16_t {2,4,8}, int32_t {1,2,4}.
2277 static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
2278 (std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
2279 (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
2280 (std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
2281 "wrong! not implemented");
2282
// float: one scalar atomic per element; element k goes to dst_wave_addr_offset + k * sizeof(float).
2283 if constexpr(std::is_same<T, float>::value)
2284 {
2285 if constexpr(N == 1)
2286 {
2288 dst_wave_buffer_resource,
2289 dst_thread_addr_offset,
2290 dst_wave_addr_offset,
2291 0);
2292 }
2293 else if constexpr(N == 2)
2294 {
2296 src_thread_data.template get_as<float>()[number<0>{}],
2297 dst_wave_buffer_resource,
2298 dst_thread_addr_offset,
2299 dst_wave_addr_offset,
2300 0);
2301
2303 src_thread_data.template get_as<float>()[number<1>{}],
2304 dst_wave_buffer_resource,
2305 dst_thread_addr_offset,
2306 dst_wave_addr_offset + sizeof(float),
2307 0);
2308 }
2309 else if constexpr(N == 4)
2310 {
2312 src_thread_data.template get_as<float>()[number<0>{}],
2313 dst_wave_buffer_resource,
2314 dst_thread_addr_offset,
2315 dst_wave_addr_offset,
2316 0);
2317
2319 src_thread_data.template get_as<float>()[number<1>{}],
2320 dst_wave_buffer_resource,
2321 dst_thread_addr_offset,
2322 dst_wave_addr_offset + sizeof(float),
2323 0);
2324
2326 src_thread_data.template get_as<float>()[number<2>{}],
2327 dst_wave_buffer_resource,
2328 dst_thread_addr_offset,
2329 dst_wave_addr_offset + 2 * sizeof(float),
2330 0);
2331
2333 src_thread_data.template get_as<float>()[number<3>{}],
2334 dst_wave_buffer_resource,
2335 dst_thread_addr_offset,
2336 dst_wave_addr_offset + 3 * sizeof(float),
2337 0);
2338 }
2339 }
// fp16: atomics operate on packed fp16x2_t pairs (N/2 calls); pair i goes to
// dst_wave_addr_offset + i * sizeof(fp16x2_t).
2340 else if constexpr(std::is_same<T, fp16_t>::value)
2341 {
2342 if constexpr(N == 2)
2343 {
2345 dst_wave_buffer_resource,
2346 dst_thread_addr_offset,
2347 dst_wave_addr_offset,
2348 0);
2349 }
2350 else if constexpr(N == 4)
2351 {
2352 static_for<0, 2, 1>{}([&](auto i) {
2354 src_thread_data.template get_as<fp16x2_t>()[i],
2355 dst_wave_buffer_resource,
2356 dst_thread_addr_offset,
2357 dst_wave_addr_offset + i * sizeof(fp16x2_t),
2358 0);
2359 });
2360 }
2361 else if constexpr(N == 8)
2362 {
2363 static_for<0, 4, 1>{}([&](auto i) {
2365 src_thread_data.template get_as<fp16x2_t>()[i],
2366 dst_wave_buffer_resource,
2367 dst_thread_addr_offset,
2368 dst_wave_addr_offset + i * sizeof(fp16x2_t),
2369 0);
2370 });
2371 }
2372 }
// bf16: same scheme as fp16, using packed bf16x2_t pairs.
2373 else if constexpr(std::is_same<T, bf16_t>::value)
2374 {
2375 if constexpr(N == 2)
2376 {
2378 dst_wave_buffer_resource,
2379 dst_thread_addr_offset,
2380 dst_wave_addr_offset,
2381 0);
2382 }
2383 else if constexpr(N == 4)
2384 {
2385 static_for<0, 2, 1>{}([&](auto i) {
2387 src_thread_data.template get_as<bf16x2_t>()[i],
2388 dst_wave_buffer_resource,
2389 dst_thread_addr_offset,
2390 dst_wave_addr_offset + i * sizeof(bf16x2_t),
2391 0);
2392 });
2393 }
2394 else if constexpr(N == 8)
2395 {
2396 static_for<0, 4, 1>{}([&](auto i) {
2398 src_thread_data.template get_as<bf16x2_t>()[i],
2399 dst_wave_buffer_resource,
2400 dst_thread_addr_offset,
2401 dst_wave_addr_offset + i * sizeof(bf16x2_t),
2402 0);
2403 });
2404 }
2405 }
// int32: one scalar atomic per element, element k at dst_wave_addr_offset + k * sizeof(int32_t).
2406 else if constexpr(std::is_same<T, int32_t>::value)
2407 {
2408 if constexpr(N == 1)
2409 {
2411 dst_wave_buffer_resource,
2412 dst_thread_addr_offset,
2413 dst_wave_addr_offset,
2414 0);
2415 }
2416 else if constexpr(N == 2)
2417 {
2419 src_thread_data.template get_as<int32_t>()[number<0>{}],
2420 dst_wave_buffer_resource,
2421 dst_thread_addr_offset,
2422 dst_wave_addr_offset,
2423 0);
2424
2426 src_thread_data.template get_as<int32_t>()[number<1>{}],
2427 dst_wave_buffer_resource,
2428 dst_thread_addr_offset,
2429 dst_wave_addr_offset + sizeof(int32_t),
2430 0);
2431 }
2432 else if constexpr(N == 4)
2433 {
2435 src_thread_data.template get_as<int32_t>()[number<0>{}],
2436 dst_wave_buffer_resource,
2437 dst_thread_addr_offset,
2438 dst_wave_addr_offset,
2439 0);
2440
2442 src_thread_data.template get_as<int32_t>()[number<1>{}],
2443 dst_wave_buffer_resource,
2444 dst_thread_addr_offset,
2445 dst_wave_addr_offset + sizeof(int32_t),
2446 0);
2447
2449 src_thread_data.template get_as<int32_t>()[number<2>{}],
2450 dst_wave_buffer_resource,
2451 dst_thread_addr_offset,
2452 dst_wave_addr_offset + 2 * sizeof(int32_t),
2453 0);
2454
2456 src_thread_data.template get_as<int32_t>()[number<3>{}],
2457 dst_wave_buffer_resource,
2458 dst_thread_addr_offset,
2459 dst_wave_addr_offset + 3 * sizeof(int32_t),
2460 0);
2461 }
2462 }
2463}
2464
// Issues raw buffer atomic-max operations for the N doubles of src_thread_data into the buffer
// described by dst_wave_buffer_resource, at per-thread byte offset dst_thread_addr_offset plus
// per-wave byte offset dst_wave_addr_offset.
// Per the doxygen tooltip, the full signature is:
//   amd_buffer_atomic_max_impl(const thread_buffer<T, N> src_thread_data,
//                              int32x4_t dst_wave_buffer_resource,
//                              index_t dst_thread_addr_offset,
//                              index_t dst_wave_addr_offset)
// Only T == double with N in {1, 2, 4} is supported.
// NOTE(review): the signature line and the call-head lines are missing from this rendered
// listing. The callee is presumably llvm_amdgcn_raw_buffer_atomic_max_fp64 (declared in the
// tooltip list) -- confirm.
// Element k is updated by its own atomic at dst_wave_addr_offset + k * sizeof(double), so the
// vector as a whole is not updated atomically.
2465template <typename T, index_t N>
2467 int32x4_t dst_wave_buffer_resource,
2468 index_t dst_thread_addr_offset,
2469 index_t dst_wave_addr_offset)
2470{
2471 static_assert((std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4)),
2472 "wrong! not implemented");
2473 if constexpr(std::is_same<T, double>::value)
2474 {
2475 if constexpr(N == 1)
2476 {
2478 dst_wave_buffer_resource,
2479 dst_thread_addr_offset,
2480 dst_wave_addr_offset,
2481 0);
2482 }
2483 else if constexpr(N == 2)
2484 {
2486 src_thread_data.template get_as<double>()[number<0>{}],
2487 dst_wave_buffer_resource,
2488 dst_thread_addr_offset,
2489 dst_wave_addr_offset,
2490 0);
2491
2493 src_thread_data.template get_as<double>()[number<1>{}],
2494 dst_wave_buffer_resource,
2495 dst_thread_addr_offset,
2496 dst_wave_addr_offset + sizeof(double),
2497 0);
2498 }
2499 else if constexpr(N == 4)
2500 {
2502 src_thread_data.template get_as<double>()[number<0>{}],
2503 dst_wave_buffer_resource,
2504 dst_thread_addr_offset,
2505 dst_wave_addr_offset,
2506 0);
2507
2509 src_thread_data.template get_as<double>()[number<1>{}],
2510 dst_wave_buffer_resource,
2511 dst_thread_addr_offset,
2512 dst_wave_addr_offset + sizeof(double),
2513 0);
2514
2516 src_thread_data.template get_as<double>()[number<2>{}],
2517 dst_wave_buffer_resource,
2518 dst_thread_addr_offset,
2519 dst_wave_addr_offset + 2 * sizeof(double),
2520 0);
2521
2523 src_thread_data.template get_as<double>()[number<3>{}],
2524 dst_wave_buffer_resource,
2525 dst_thread_addr_offset,
2526 dst_wave_addr_offset + 3 * sizeof(double),
2527 0);
2528 }
2529 }
2530}
2531
2532// buffer_load requires:
2533// 1) p_src_wave must point to global memory space
2534// 2) p_src_wave must be a wavewise pointer.
2535// It is user's responsibility to make sure that is true.
2536// oob_conditional_check : dynamic check if out-of-bound
// Loads N elements of T starting at element src_thread_element_offset of p_src_wave. When
// oob_conditional_check is true, lanes whose src_thread_element_valid is false get zeros.
// Per the doxygen tooltip, the full signature is:
//   thread_buffer<T, N> amd_buffer_load_invalid_element_return_zero(
//       const T* p_src_wave, index_t src_thread_element_offset,
//       bool src_thread_element_valid, index_t src_element_space_size)
// NOTE(review): several lines are missing from this rendered listing:
//   - line 2539, presumably the `coherence` template parameter used below;
//   - lines 2541-2542, the return type, name and p_src_wave;
//   - lines 2559 and 2562, the load-call heads.
2537template <typename T,
2538 index_t N,
2540 bool oob_conditional_check = true>
2543 index_t src_thread_element_offset,
2544 bool src_thread_element_valid,
2545 index_t src_element_space_size)
2546{
// The resource's size argument is the buffer extent in bytes (elements * sizeof(T)).
2547 const int32x4_t src_wave_buffer_resource =
2548 make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
2549
// Element offset -> byte offset.
2550 index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
2551
2552#if CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
// Instead of selecting after the load, push an invalid lane's byte offset up by 0x80000000 so
// that it falls outside the resource's byte range. This relies on the buffer unit's
// out-of-range handling (presumably returning 0 for such lanes) -- confirm against the ISA docs.
2553 uint32_t src_addr_shift = [&]() {
2554 if constexpr(oob_conditional_check)
2555 return src_thread_element_valid ? 0 : 0x80000000;
2556 else
2557 return 0;
2558 }();
2560 src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
2561#else
// Always issue the load, then substitute a thread_buffer built from numeric<T>::zero() for
// invalid lanes.
2563 amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
2564 if constexpr(oob_conditional_check)
2565 return src_thread_element_valid ? tmp : thread_buffer<T, N>{numeric<T>::zero()};
2566 else
2567 return tmp;
2568#endif
2569}
2570
2571// buffer_load requires:
2572// 1) p_src_wave must point to global memory space
2573// 2) p_src_wave must be a wavewise pointer.
2574// It is user's responsibility to make sure that is true.
// Loads N elements of T starting at element src_thread_element_offset of p_src_wave. When
// oob_conditional_check is true, lanes whose src_thread_element_valid is false instead get
// thread_buffer<T, N>{customized_value}.
// Per the doxygen tooltip, the full signature is:
//   thread_buffer<T, N> amd_buffer_load_invalid_element_return_customized_value(
//       const T* p_src_wave, index_t src_thread_element_offset,
//       bool src_thread_element_valid, index_t src_element_space_size, T customized_value)
// Unlike the return-zero variant, there is no offset-trick path: the load is always issued
// and the result is selected afterwards.
// NOTE(review): whether thread_buffer<T, N>{customized_value} broadcasts the value to all N
// elements or only sets the first depends on thread_buffer's constructor -- confirm.
// Lines 2577 (presumably `coherence`), 2579-2580 and 2591 are missing from this rendered listing.
2575template <typename T,
2576 index_t N,
2578 bool oob_conditional_check = true>
2581 index_t src_thread_element_offset,
2582 bool src_thread_element_valid,
2583 index_t src_element_space_size,
2584 T customized_value)
2585{
2586 const int32x4_t src_wave_buffer_resource =
2587 make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
2588
// Element offset -> byte offset.
2589 index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
2590
2592 amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
2593
2594 if constexpr(oob_conditional_check)
2595 return src_thread_element_valid ? tmp : thread_buffer<T, N>{customized_value};
2596 else
2597 return tmp;
2598}
2599
// Loads N elements of T from p_src_wave into dst (pointer-based overload: builds the buffer
// resource itself from p_src_wave and src_element_space_size).
// Per the doxygen tooltip, the full signature is:
//   void amd_buffer_load_raw(thread_buffer<T, N>& dst, const T* p_src_wave,
//                            index_t src_thread_element_offset,
//                            index_t src_linear_element_offset,
//                            index_t src_element_space_size,
//                            index_t is_valid_element = 0,
//                            bool_constant<pre_nop> = {})
// Both element offsets are converted to byte offsets. They are forwarded to the callee as
// (resource, thread offset, wave offset 0, linear offset, is_valid_element).
// NOTE(review): the callee (line 2619) and the final argument line (2626) are missing from this
// rendered listing, so how oob_conditional_check / pre_nop are used is not visible here.
2600template <typename T,
2601 index_t N,
2603 bool oob_conditional_check = true,
2604 bool pre_nop = false>
2606 const T* p_src_wave,
2607 index_t src_thread_element_offset,
2608 index_t src_linear_element_offset,
2609 index_t src_element_space_size,
2610 index_t is_valid_element = 0,
2612{
2613 const int32x4_t src_wave_buffer_resource =
2614 make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
2615
2616 index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
2617 index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
2618
2620 dst,
2621 src_wave_buffer_resource,
2622 src_thread_addr_offset,
2623 0,
2624 src_linear_addr_offset,
2625 is_valid_element,
2627}
2628
2629// This version support buffer resource as input arg
// Same as the pointer-based overload, but takes a prebuilt src_wave_buffer_resource, so no
// element-space size is needed.
// NOTE(review): the name/return-type line (2635) is missing from this rendered listing. This is
// presumably the resource overload of amd_buffer_load_raw -- confirm. The callee (2645) and the
// last parameter/argument lines (2640, 2652) are also elided.
2630template <typename T,
2631 index_t N,
2633 bool oob_conditional_check = true,
2634 bool pre_nop = false>
2636 const int32x4_t src_wave_buffer_resource,
2637 index_t src_thread_element_offset,
2638 index_t src_linear_element_offset,
2639 index_t is_valid_element = 0,
2641{
// Element offsets -> byte offsets.
2642 index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
2643 index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
2644
2646 dst,
2647 src_wave_buffer_resource,
2648 src_thread_addr_offset,
2649 0,
2650 src_linear_addr_offset,
2651 is_valid_element,
2653}
2654
2655// unfortunately async copy can not make sure invalid data is zero inside LDS
2656// ... unless people manually write zero to LDS at the proper address.
2657// so not support invalid_element check for now.
2658// buffer_load OOB still working.
// In other words: an asynchronous buffer -> LDS copy cannot guarantee that invalid elements
// land in LDS as zeros unless zeros are written to LDS separately. This function therefore
// takes no is_valid_element argument; only the buffer resource's own range check
// (src_element_space_size) applies.
// Per the doxygen tooltip, the full signature is:
//   void amd_async_buffer_load_with_oob_raw(T* smem, const T* p_src_wave,
//                                           index_t src_thread_element_offset,
//                                           index_t src_linear_element_offset,
//                                           index_t src_element_space_size,
//                                           bool_constant<pre_nop> = {})
// NOTE(review): the callee (line 2676) and lines 2661/2663/2668/2681 are missing from this
// rendered listing.
2659template <typename T,
2660 index_t N,
2662 bool pre_nop = false>
2664 const T* p_src_wave,
2665 index_t src_thread_element_offset,
2666 index_t src_linear_element_offset,
2667 index_t src_element_space_size,
2669{
2670 const int32x4_t src_wave_buffer_resource =
2671 make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
2672
// Element offsets -> byte offsets; the wave offset passed on is 0.
2673 index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
2674 index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
2675
2677 src_wave_buffer_resource,
2678 src_thread_addr_offset,
2679 0,
2680 src_linear_addr_offset,
2682}
2683
2684// This version support buffer resource as input arg
// Resource-based variant of the async buffer -> LDS copy. The caller supplies
// src_wave_buffer_resource, so the OOB range is whatever that resource encodes. There is no
// per-element validity flag.
// NOTE(review): the name line (2689) is missing from this rendered listing. This is presumably
// the resource overload of amd_async_buffer_load_with_oob_raw -- confirm. The callee (2698) and
// lines 2687/2693/2703 are also elided.
2685template <typename T,
2686 index_t N,
2688 bool pre_nop = false>
2690 const int32x4_t src_wave_buffer_resource,
2691 index_t src_thread_element_offset,
2692 index_t src_linear_element_offset,
2694{
2695 index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
2696 index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
2697
2699 src_wave_buffer_resource,
2700 src_thread_addr_offset,
2701 0,
2702 src_linear_addr_offset,
2704}
2705
2706// This version support buffer resource as input arg
// Async buffer -> LDS copy that, unlike the *_raw variants, also forwards a per-lane
// is_valid_element flag. oob_conditional_check defaults to false here.
// Per the doxygen tooltip, the full signature is:
//   void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
//                                       const int32x4_t src_wave_buffer_resource,
//                                       index_t src_thread_element_offset,
//                                       index_t src_linear_element_offset,
//                                       bool is_valid_element,
//                                       bool_constant<oob_conditional_check> = {})
// NOTE(review): the callee (line 2721) and lines 2709/2711/2716/2727 are missing from this
// rendered listing, so how the flag is consumed is not visible here.
2707template <typename T,
2708 index_t N,
2710 bool oob_conditional_check = false>
2712 const int32x4_t src_wave_buffer_resource,
2713 index_t src_thread_element_offset,
2714 index_t src_linear_element_offset,
2715 bool is_valid_element,
2717{
2718 index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
2719 index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
2720
2722 src_wave_buffer_resource,
2723 src_thread_addr_offset,
2724 0,
2725 src_linear_addr_offset,
2726 is_valid_element,
2728}
2729
2730// buffer_store requires:
2731// 1) p_dst_wave must point to global memory
2732// 2) p_dst_wave must be a wavewise pointer.
2733// It is user's responsibility to make sure that is true.
// Stores the N elements of src_thread_data to p_dst_wave at element dst_thread_element_offset.
// When oob_conditional_check is true, lanes with dst_thread_element_valid == false do not store.
// Per the doxygen tooltip, the full signature is:
//   void amd_buffer_store(const thread_buffer<T, N>& src_thread_data, T* p_dst_wave,
//                         const index_t dst_thread_element_offset,
//                         const bool dst_thread_element_valid,
//                         const index_t dst_element_space_size)
// NOTE(review): lines 2736 (presumably a `coherence` template parameter), 2738 and the store-call
// heads (2756, 2763, 2769) are missing from this rendered listing.
2734template <typename T,
2735 index_t N,
2737 bool oob_conditional_check = true>
2739 T* p_dst_wave,
2740 const index_t dst_thread_element_offset,
2741 const bool dst_thread_element_valid,
2742 const index_t dst_element_space_size)
2743{
// The resource's size argument is the buffer extent in bytes.
2744 const int32x4_t dst_wave_buffer_resource =
2745 make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
2746
2747 index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
2748
2749#if CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
// Branch-free variant: invalid lanes get 0x80000000 added to their byte offset so the store
// falls outside the resource's range. This presumably relies on the buffer unit dropping
// out-of-range stores -- confirm.
2750 uint32_t dst_addr_shift = [&]() {
2751 if constexpr(oob_conditional_check)
2752 return dst_thread_element_valid ? 0 : 0x80000000;
2753 else
2754 return 0;
2755 }();
2757 src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
2758#else
// Branching variant: the store is simply skipped for invalid lanes.
2759 if constexpr(oob_conditional_check)
2760 {
2761 if(dst_thread_element_valid)
2762 {
2764 src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
2765 }
2766 }
2767 else
2768 {
2770 src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
2771 }
2772#endif
2773}
2774
// Stores src_thread_data to p_dst_wave using separate per-thread and linear element offsets.
// Validity is forwarded to the (elided) callee rather than being handled here.
// Per the doxygen tooltip, the full signature is:
//   void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data, T* p_dst_wave,
//                             const index_t dst_thread_element_offset,
//                             const index_t dst_linear_element_offset,
//                             const bool dst_thread_element_valid,
//                             const index_t dst_element_space_size)
// NOTE(review): the callee (line 2792) and lines 2777/2779 are missing from this rendered
// listing, so how oob_conditional_check is applied is not visible here.
2775template <typename T,
2776 index_t N,
2778 bool oob_conditional_check = true>
2780 T* p_dst_wave,
2781 const index_t dst_thread_element_offset,
2782 const index_t dst_linear_element_offset,
2783 const bool dst_thread_element_valid,
2784 const index_t dst_element_space_size)
2785{
2786 const int32x4_t dst_wave_buffer_resource =
2787 make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
2788
// Element offsets -> byte offsets; the wave offset passed on is 0.
2789 index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
2790 index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
2791
2793 dst_wave_buffer_resource,
2794 dst_thread_addr_offset,
2795 0,
2796 dst_linear_addr_offset,
2797 dst_thread_element_valid);
2798}
2799
2800// buffer_atomic_add requires:
2801// 1) p_dst_wave must point to global memory
2802// 2) p_dst_wave must be a wavewise pointer.
2803// It is user's responsibility to make sure that is true.
// Atomically adds src_thread_data into p_dst_wave at element dst_thread_element_offset. Lanes
// with dst_thread_element_valid == false perform no update. Unlike amd_buffer_store, there is
// no oob_conditional_check parameter: validity is always honoured.
// Per the doxygen tooltip, the full signature is:
//   void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_data, T* p_dst_wave,
//                              const index_t dst_thread_element_offset,
//                              const bool dst_thread_element_valid,
//                              const index_t dst_element_space_size)
// NOTE(review): the call heads (lines 2819, 2824) are missing from this rendered listing.
2804template <typename T, index_t N>
2806 T* p_dst_wave,
2807 const index_t dst_thread_element_offset,
2808 const bool dst_thread_element_valid,
2809 const index_t dst_element_space_size)
2810{
2811 const int32x4_t dst_wave_buffer_resource =
2812 make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
2813
2814 index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
2815
2816#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
// Invalid lanes are pushed 0x80000000 bytes out of the resource's range instead of branching.
// This presumably relies on the hardware dropping out-of-range atomics -- confirm.
2817 uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
2818
2820 src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
2821#else
2822 if(dst_thread_element_valid)
2823 {
2825 src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
2826 }
2827#endif
2828}
2829
// Atomically adds src_thread_data into p_dst_wave using separate per-thread and linear element
// offsets, and dispatches on oob_conditional_check:
//   - true:  buffer_atomic_add_if, which receives dst_thread_element_valid;
//   - false: buffer_atomic_add, which receives the constant 1 in that position.
// pre_nop is forwarded as a template argument to both.
// Per the doxygen tooltip, the full signature is:
//   void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data, T* p_dst_wave,
//                                  const index_t dst_thread_element_offset,
//                                  const index_t dst_linear_element_offset,
//                                  const bool dst_thread_element_valid,
//                                  const index_t dst_element_space_size,
//                                  bool_constant<pre_nop> = {})
2830template <typename T,
2831 index_t N,
2833 bool oob_conditional_check = true,
2834 bool pre_nop = false>
2836 T* p_dst_wave,
2837 const index_t dst_thread_element_offset,
2838 const index_t dst_linear_element_offset,
2839 const bool dst_thread_element_valid,
2840 const index_t dst_element_space_size,
2842{
2843 const int32x4_t dst_wave_buffer_resource =
2844 make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
2845
// Element offsets -> byte offsets; the wave offset passed on is 0.
2846 index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
2847 index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
2848
2849 if constexpr(oob_conditional_check)
2850 {
2851 buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
2852 dst_wave_buffer_resource,
2853 dst_thread_addr_offset,
2854 0,
2855 dst_linear_addr_offset,
2856 dst_thread_element_valid);
2857 }
2858 else
2859 {
2860 buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
2861 dst_wave_buffer_resource,
2862 dst_thread_addr_offset,
2863 0,
2864 dst_linear_addr_offset,
2865 1);
2866 }
2867}
2868
2869// buffer_atomic_max requires:
2870// 1) p_dst_wave must point to global memory
2871// 2) p_dst_wave must be a wavewise pointer.
2872// It is user's responsibility to make sure that is true.
// Atomically takes the max of src_thread_data and p_dst_wave at element
// dst_thread_element_offset. Lanes with dst_thread_element_valid == false perform no update.
// Per the doxygen tooltip, the full signature is:
//   void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_data, T* p_dst_wave,
//                              const index_t dst_thread_element_offset,
//                              const bool dst_thread_element_valid,
//                              const index_t dst_element_space_size)
// NOTE(review): the call heads (lines 2888, 2893) are missing from this rendered listing.
2873template <typename T, index_t N>
2875 T* p_dst_wave,
2876 const index_t dst_thread_element_offset,
2877 const bool dst_thread_element_valid,
2878 const index_t dst_element_space_size)
2879{
2880 const int32x4_t dst_wave_buffer_resource =
2881 make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
2882
2883 index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
2884
2885#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
// Invalid lanes are pushed 0x80000000 bytes out of the resource's range instead of branching.
// This presumably relies on the hardware dropping out-of-range atomics -- confirm.
2886 uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
2887
2889 src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
2890#else
2891 if(dst_thread_element_valid)
2892 {
2894 src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
2895 }
2896#endif
2897}
2898
#if defined(__gfx950__)
// Loads one 64-bit transposed fragment from LDS into registers using the gfx950 ds_read_tr*
// builtins, and returns it bit-cast to thread_buffer<T, N>.
//
// in_ptr  : source pointer. It is cast to an address_space(3) (LDS) pointer, so it must really
//           point into LDS.
// returns : thread_buffer<T, N> holding the 8 bytes produced by the builtin.
//
// Supported element types (cv/ref qualifiers ignored) and the builtin used for each:
//   half_t                 -> __builtin_amdgcn_ds_read_tr16_b64_v4f16
//   bf16_t                 -> __builtin_amdgcn_ds_read_tr16_b64_v4bf16
//   fp8_t / bf8_t / int8_t -> __builtin_amdgcn_ds_read_tr8_b64_v2i32
// Each builtin returns exactly 8 bytes, so thread_buffer<T, N> must be 8 bytes:
// N == 4 for 16-bit types and N == 8 for 8-bit types.
template <typename T, index_t N>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
#define __LDS_ADDR __attribute__((address_space(3)))

    using element_t = remove_cvref_t<T>;

    // Every ds_read_tr*_b64 variant yields 64 bits. Reject a mismatched N here with a clear
    // message instead of an obscure failure inside bit_cast.
    static_assert(sizeof(thread_buffer<T, N>) == 8,
                  "amd_transpose_load_to_vgpr: thread_buffer<T, N> must be 8 bytes "
                  "(N == 4 for 16-bit types, N == 8 for 8-bit types)");

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    // Use C-style cast to change address space without dropping llvm noalias attribute
    const auto in_ptr_ = (__LDS_ADDR T*)(const_cast<T*>(in_ptr));
#pragma clang diagnostic pop
    if constexpr(std::is_same_v<element_t, ck_tile::half_t>)
    {
        // Verify the builtin that is actually used below, not an unrelated one.
        static_assert(__has_builtin(__builtin_amdgcn_ds_read_tr16_b64_v4f16),
                      "We need to have the compatible compiler version to build this instruction");
        typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
        auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_fp16x4_t*>(in_ptr_);
        return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
    }
    else if constexpr(std::is_same_v<element_t, ck_tile::bf16_t>)
    {
        static_assert(__has_builtin(__builtin_amdgcn_ds_read_tr16_b64_v4bf16),
                      "We need to have the compatible compiler version to build this instruction");
        typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
        auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_bf16x4_t*>(in_ptr_);
        return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
    }
    else if constexpr(std::is_same_v<element_t, ck_tile::fp8_t> ||
                      std::is_same_v<element_t, ck_tile::bf8_t> ||
                      std::is_same_v<element_t, ck_tile::int8_t>)
    {
        static_assert(__has_builtin(__builtin_amdgcn_ds_read_tr8_b64_v2i32),
                      "We need to have the compatible compiler version to build this instruction");
        typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
        auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_);
        return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
    }
    else
    {
        // Type-dependent condition: a literal static_assert(false) is rejected at template
        // definition time by compilers that predate P2593R1, even in a discarded branch.
        static_assert(sizeof(T) == 0, "not implemented");
    }
#undef __LDS_ADDR
}
#endif
2940
2941} // namespace ck_tile
2942
2943#endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD
Definition config.hpp:210
#define CK_TILE_DEVICE_EXTERN
Definition config.hpp:43
#define CK_TILE_LDS_ADDR
Definition config.hpp:58
Definition tile/core/arch/amd_buffer_addressing.hpp:110
CK_TILE_DEVICE void insert_dummy_dep_per_dword< 3 >(array< float, 3 > &b)
Definition tile/core/arch/amd_buffer_addressing.hpp:997
CK_TILE_DEVICE void insert_dummy_dep()
Definition tile/core/arch/amd_buffer_addressing.hpp:1037
CK_TILE_DEVICE void insert_dummy_dep_per_dword< 32 >(array< float, 32 > &b)
Definition tile/core/arch/amd_buffer_addressing.hpp:1025
CK_TILE_DEVICE void insert_dummy_dep_per_dword< 8 >(array< float, 8 > &b)
Definition tile/core/arch/amd_buffer_addressing.hpp:1009
CK_TILE_DEVICE void insert_dummy_dep_per_dword(array< float, N > &b)
Definition tile/core/arch/amd_buffer_addressing.hpp:981
CK_TILE_DEVICE void insert_dummy_dep_per_dword< 2 >(array< float, 2 > &b)
Definition tile/core/arch/amd_buffer_addressing.hpp:991
CK_TILE_DEVICE void insert_dummy_dep_per_dword< 4 >(array< float, 4 > &b)
Definition tile/core/arch/amd_buffer_addressing.hpp:1003
CK_TILE_DEVICE void insert_dummy_dep_per_dword< 16 >(array< float, 16 > &b)
Definition tile/core/arch/amd_buffer_addressing.hpp:1016
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE_EXTERN int8x2_t llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8")
_Float16 fp16x2_t
Definition half.hpp:385
CK_TILE_DEVICE thread_buffer< T, N > amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset)
Definition tile/core/arch/amd_buffer_addressing.hpp:1535
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
int8_t int8x2_t
Definition pk_int4.hpp:103
CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer< T, N > &src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition tile/core/arch/amd_buffer_addressing.hpp:2874
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16")
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
_Float16 half_t
Definition half.hpp:111
CK_TILE_DEVICE_EXTERN fp16x4_t llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16")
uint16_t uint16x2_t
Definition vector_type.hpp:181
int16_t int16x4_t
Definition vector_type.hpp:173
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8")
CK_TILE_DEVICE_EXTERN int8_t llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8")
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8")
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
int8_t int8_t
Definition int8.hpp:20
CK_TILE_DEVICE void amd_buffer_store(const thread_buffer< T, N > &src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition tile/core/arch/amd_buffer_addressing.hpp:2738
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i16x2(int16x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16")
bfloat16_t bf16_t
Definition bfloat16.hpp:113
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt=0)
Definition tile/core/arch/amd_buffer_addressing.hpp:1068
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i16x4(int16x4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16")
_Float16 fp16_t
Definition half.hpp:110
amd_buffer_coherence_enum
Definition tile/core/arch/amd_buffer_addressing.hpp:1404
@ glc_slc
Definition tile/core/arch/amd_buffer_addressing.hpp:1408
@ SYSTEM_NT1
Definition tile/core/arch/amd_buffer_addressing.hpp:1419
@ coherence_default
Definition tile/core/arch/amd_buffer_addressing.hpp:1405
@ WAVE_NT0
Definition tile/core/arch/amd_buffer_addressing.hpp:1412
@ slc
Definition tile/core/arch/amd_buffer_addressing.hpp:1407
@ DEVICE_NT1
Definition tile/core/arch/amd_buffer_addressing.hpp:1417
@ SYSTEM_NT0
Definition tile/core/arch/amd_buffer_addressing.hpp:1418
@ glc
Definition tile/core/arch/amd_buffer_addressing.hpp:1406
@ GROUP_NT1
Definition tile/core/arch/amd_buffer_addressing.hpp:1415
@ DEVICE_NT0
Definition tile/core/arch/amd_buffer_addressing.hpp:1416
@ GROUP_NT0
Definition tile/core/arch/amd_buffer_addressing.hpp:1414
@ WAVE_NT1
Definition tile/core/arch/amd_buffer_addressing.hpp:1413
CK_TILE_DEVICE_EXTERN double llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int32x4_t rsrc, int voffset, int soffset, int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64")
CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32(float vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32")
_BitInt(8) fp8_t
Definition float8.hpp:204
CK_TILE_DEVICE_EXTERN fp32x2_t llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32")
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T *smem, const int32x4_t src_wave_buffer_resource, index_t src_thread_element_offset, index_t src_linear_element_offset, bool is_valid_element, bool_constant< oob_conditional_check >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:2711
tuple_array< T, N > thread_buffer
Definition thread_buffer.hpp:14
int32_t int32x4_t
Definition vector_type.hpp:155
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T *smem, const T *p_src_wave, index_t src_thread_element_offset, index_t src_linear_element_offset, index_t src_element_space_size, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:2663
bfloat16_t bf16x2_t
Definition pk_fp4.hpp:24
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16")
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32")
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16")
CK_TILE_DEVICE_EXTERN int32x4_t llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32")
CK_TILE_DEVICE void lds_load_fence(index_t cnt=0)
Definition tile/core/arch/amd_buffer_addressing.hpp:820
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8")
uint32_t uint32x4_t
Definition vector_type.hpp:164
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer< int8_t, N > src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition tile/core/arch/amd_buffer_addressing.hpp:1926
CK_TILE_DEVICE_EXTERN bf16x2_t llvm_amdgcn_raw_buffer_atomic_add_bf16x2(bf16x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16")
CK_TILE_DEVICE void async_buffer_load_dwordxn_v(void *smem, int32x4_t rsrc, index_t voffset, index_t, index_t ioffset, index_t=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:1352
_Float16 fp16x4_t
Definition vector_type.hpp:137
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32")
CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer< T, N > src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition tile/core/arch/amd_buffer_addressing.hpp:2466
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer< T, N > &src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const index_t dst_linear_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:2835
CK_TILE_DEVICE_EXTERN int32x2_t llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32")
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer< T, N > &src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition tile/core/arch/amd_buffer_addressing.hpp:2272
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
CK_TILE_DEVICE void buffer_store_fence(index_t cnt=0)
Definition tile/core/arch/amd_buffer_addressing.hpp:1063
CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32")
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(fp16x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16")
bfloat16_t bf16x4_t
Definition vector_type.hpp:146
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32")
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer< T, N > &dst, const T *p_src_wave, index_t src_thread_element_offset, index_t src_linear_element_offset, index_t src_element_space_size, index_t is_valid_element=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:2605
bfloat16_t bf16x8_t
Definition vector_type.hpp:147
CK_TILE_DEVICE thread_buffer< T, N > amd_buffer_load_invalid_element_return_zero(const T *p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size)
Definition tile/core/arch/amd_buffer_addressing.hpp:2542
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt=0)
Definition tile/core/arch/amd_buffer_addressing.hpp:1393
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer< T, N > &src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const index_t dst_linear_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition tile/core/arch/amd_buffer_addressing.hpp:2779
CK_TILE_DEVICE_EXTERN fp32x4_t llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32")
CK_TILE_DEVICE void buffer_load_fence(index_t cnt=0)
Definition tile/core/arch/amd_buffer_addressing.hpp:815
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
CK_TILE_DEVICE_EXTERN int16_t llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16")
CK_TILE_DEVICE_EXTERN int16x4_t llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16")
CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer< T, N > src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition tile/core/arch/amd_buffer_addressing.hpp:2026
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32")
float fp32x4_t
Definition vector_type.hpp:128
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr, uint32_t size=0xffffffff, ForceSGPR={})
Definition tile/core/arch/amd_buffer_addressing.hpp:97
CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer< T, N > &src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition tile/core/arch/amd_buffer_addressing.hpp:2805
uint16_t uint16x4_t
Definition vector_type.hpp:182
float fp32x2_t
Definition pk_fp4.hpp:22
int8_t int8x4_t
Definition vector_type.hpp:191
CK_TILE_DEVICE_EXTERN int16x2_t llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16")
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T *smem, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, index_t src_immediate_addr_offset=0, index_t flag=0, bool_constant< oob_conditional_check >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:1882
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16")
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE thread_buffer< T, N > amd_buffer_load_invalid_element_return_customized_value(const T *p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size, T customized_value)
Definition tile/core/arch/amd_buffer_addressing.hpp:2580
int32_t int32x2_t
Definition vector_type.hpp:154
CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16")
CK_TILE_DEVICE_EXTERN _Float16 llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16")
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16")
CK_TILE_DEVICE thread_buffer< int8_t, N > amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset)
Definition tile/core/arch/amd_buffer_addressing.hpp:1425
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, as3_uint32_ptr lds_ptr, index_t size, index_t voffset, index_t soffset, index_t offset, index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds")
CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T *smem, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, index_t src_immediate_addr_offset=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:1857
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(int32_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32")
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16")
int16_t int16x2_t
Definition vector_type.hpp:172
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32")
CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer< T, N > &dst_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset, index_t dst_linear_addr_offset, index_t is_valid_element=1)
Definition tile/core/arch/amd_buffer_addressing.hpp:2240
uint32_t uint32x2_t
Definition vector_type.hpp:163
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32")
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32(float vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32")
CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16")
CK_TILE_DEVICE_EXTERN int8x4_t llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8")
CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer< T, N > &dst, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, index_t src_linear_addr_offset, index_t flag=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:1818
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
signed short int16_t
Definition stdint.h:122
unsigned short uint16_t
Definition stdint.h:125
unsigned int uint32_t
Definition stdint.h:126
unsigned char uint8_t
Definition stdint.h:124
signed char int8_t
Definition stdint.h:121
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
CK_TILE_HOST_DEVICE constexpr auto & get()
Definition tile/core/container/array.hpp:101
CK_TILE_DEVICE void operator()(const T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t)
Definition tile/core/arch/amd_buffer_addressing.hpp:863
CK_TILE_DEVICE void operator()(const T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t flag=1)
Definition tile/core/arch/amd_buffer_addressing.hpp:832
Definition tile/core/arch/amd_buffer_addressing.hpp:826
Definition tile/core/arch/amd_buffer_addressing.hpp:857
CK_TILE_DEVICE void operator()(T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:166
CK_TILE_DEVICE void operator()(T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:304
CK_TILE_DEVICE void operator()(T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:269
CK_TILE_DEVICE void operator()(T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:234
CK_TILE_DEVICE void operator()(T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:200
CK_TILE_DEVICE void operator()(T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t flag=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:359
CK_TILE_DEVICE void operator()(T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t flag=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:492
CK_TILE_DEVICE void operator()(T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t flag=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:459
CK_TILE_DEVICE void operator()(T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t flag=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:426
CK_TILE_DEVICE void operator()(T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t flag=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:393
Definition tile/core/arch/amd_buffer_addressing.hpp:134
Definition tile/core/arch/amd_buffer_addressing.hpp:131
Definition tile/core/arch/amd_buffer_addressing.hpp:90
const void * ptr
Definition tile/core/arch/amd_buffer_addressing.hpp:91
uint32_t range
Definition tile/core/arch/amd_buffer_addressing.hpp:92
uint32_t config
Definition tile/core/arch/amd_buffer_addressing.hpp:93
CK_TILE_DEVICE void operator()(const T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t=1)
Definition tile/core/arch/amd_buffer_addressing.hpp:528
CK_TILE_DEVICE void operator()(const T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t=1)
Definition tile/core/arch/amd_buffer_addressing.hpp:632
CK_TILE_DEVICE void operator()(const T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t=1)
Definition tile/core/arch/amd_buffer_addressing.hpp:606
CK_TILE_DEVICE void operator()(const T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t=1)
Definition tile/core/arch/amd_buffer_addressing.hpp:580
CK_TILE_DEVICE void operator()(const T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t=1)
Definition tile/core/arch/amd_buffer_addressing.hpp:554
CK_TILE_DEVICE void operator()(const T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t flag=1)
Definition tile/core/arch/amd_buffer_addressing.hpp:677
CK_TILE_DEVICE void operator()(const T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t flag=1)
Definition tile/core/arch/amd_buffer_addressing.hpp:790
CK_TILE_DEVICE void operator()(const T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t flag=1)
Definition tile/core/arch/amd_buffer_addressing.hpp:762
CK_TILE_DEVICE void operator()(const T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t flag=1)
Definition tile/core/arch/amd_buffer_addressing.hpp:734
CK_TILE_DEVICE void operator()(const T &value, int32x4_t res, index_t v_offset, index_t, index_t i_offset, index_t flag=1)
Definition tile/core/arch/amd_buffer_addressing.hpp:705
Definition tile/core/arch/amd_buffer_addressing.hpp:140
Definition tile/core/arch/amd_buffer_addressing.hpp:137
fp32x4_t payload_t
Definition tile/core/arch/amd_buffer_addressing.hpp:115
float payload_t
Definition tile/core/arch/amd_buffer_addressing.hpp:119
float payload_t
Definition tile/core/arch/amd_buffer_addressing.hpp:118
float payload_t
Definition tile/core/arch/amd_buffer_addressing.hpp:117
fp32x2_t payload_t
Definition tile/core/arch/amd_buffer_addressing.hpp:116
Definition tile/core/arch/amd_buffer_addressing.hpp:113
fp32x4_t payload_t
Definition tile/core/arch/amd_buffer_addressing.hpp:884
float payload_t
Definition tile/core/arch/amd_buffer_addressing.hpp:888
float payload_t
Definition tile/core/arch/amd_buffer_addressing.hpp:887
float payload_t
Definition tile/core/arch/amd_buffer_addressing.hpp:886
fp32x2_t payload_t
Definition tile/core/arch/amd_buffer_addressing.hpp:885
Definition tile/core/arch/amd_buffer_addressing.hpp:882
static CK_TILE_HOST_DEVICE constexpr T zero()
Definition tile/core/numeric/numeric.hpp:58
Definition coordinate_transform.hpp:1392
CK_TILE_DEVICE void operator()(T &value, index_t v_offset, index_t i_offset)
Definition tile/core/arch/amd_buffer_addressing.hpp:901
CK_TILE_DEVICE void operator()(T &value, index_t v_offset, index_t i_offset)
Definition tile/core/arch/amd_buffer_addressing.hpp:961
CK_TILE_DEVICE void operator()(T &value, index_t v_offset, index_t i_offset)
Definition tile/core/arch/amd_buffer_addressing.hpp:946
CK_TILE_DEVICE void operator()(T &value, index_t v_offset, index_t i_offset)
Definition tile/core/arch/amd_buffer_addressing.hpp:931
CK_TILE_DEVICE void operator()(T &value, index_t v_offset, index_t i_offset)
Definition tile/core/arch/amd_buffer_addressing.hpp:916
Definition tile/core/arch/amd_buffer_addressing.hpp:895
Definition tile/core/utility/functional.hpp:43
Definition tile/core/utility/debug.hpp:67
uint32_t * as3_uint32_ptr
Definition tile/core/arch/amd_buffer_addressing.hpp:29
#define CK_TILE_ASYNC_LOAD_WITH_INSTR(instr)
#define LIKELY(x)
Definition tile/core/arch/amd_buffer_addressing.hpp:26