Program Listing for File util_ptx.hpp¶
↰ Return to documentation for file (hipcub/backend/rocprim/util_ptx.hpp)
/******************************************************************************
* Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
* Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#ifndef HIPCUB_ROCPRIM_UTIL_PTX_HPP_
#define HIPCUB_ROCPRIM_UTIL_PTX_HPP_
#include <cstdint>
#include <type_traits>
#include "../../config.hpp"
#include <rocprim/intrinsics/warp_shuffle.hpp>
#define HIPCUB_WARP_THREADS ::rocprim::warp_size()
#define HIPCUB_DEVICE_WARP_THREADS ::rocprim::device_warp_size()
#define HIPCUB_HOST_WARP_THREADS ::rocprim::host_warp_size()
#define HIPCUB_ARCH 1 // ignored with rocPRIM backend
BEGIN_HIPCUB_NAMESPACE
// Missing compared to CUB:
// * ThreadExit - not supported
// * ThreadTrap - not supported
// * FFMA_RZ, FMUL_RZ - not in CUB public API
// * WARP_SYNC - not supported, not CUB public API
// * CTA_SYNC_AND - not supported, not CUB public API
// * MatchAny - not in CUB public API
//
// Differences:
// * Warp thread masks (when used) are 64-bit unsigned integers
// * member_mask argument is ignored in WARP_[ALL|ANY|BALLOT] funcs
// * Arguments first_lane, last_lane, and member_mask are ignored
// in Shuffle* funcs
// * count in BAR is ignored, BAR works like CTA_SYNC
// ID functions etc.
HIPCUB_DEVICE inline
int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z)
{
return ((block_dim_z == 1) ? 0 : (hipThreadIdx_z * block_dim_x * block_dim_y))
+ ((block_dim_y == 1) ? 0 : (hipThreadIdx_y * block_dim_x))
+ hipThreadIdx_x;
}
HIPCUB_DEVICE inline
unsigned int LaneId()
{
return ::rocprim::lane_id();
}
HIPCUB_DEVICE inline
unsigned int WarpId()
{
return ::rocprim::warp_id();
}
template <int LOGICAL_WARP_THREADS, int /* ARCH */ = 0>
HIPCUB_DEVICE inline
uint64_t WarpMask(unsigned int warp_id) {
constexpr bool is_pow_of_two = ::rocprim::detail::is_power_of_two(LOGICAL_WARP_THREADS);
constexpr bool is_arch_warp =
LOGICAL_WARP_THREADS == ::rocprim::device_warp_size();
uint64_t member_mask = uint64_t(-1) >> (64 - LOGICAL_WARP_THREADS);
if (is_pow_of_two && !is_arch_warp) {
member_mask <<= warp_id * LOGICAL_WARP_THREADS;
}
return member_mask;
}
// Returns the warp lane mask of all lanes less than the calling thread
HIPCUB_DEVICE inline
uint64_t LaneMaskLt()
{
return (uint64_t(1) << LaneId()) - 1;
}
// Returns the warp lane mask of all lanes less than or equal to the calling thread
HIPCUB_DEVICE inline
uint64_t LaneMaskLe()
{
return ((uint64_t(1) << LaneId()) << 1) - 1;
}
// Returns the warp lane mask of all lanes greater than the calling thread
HIPCUB_DEVICE inline
uint64_t LaneMaskGt()
{
return uint64_t(-1)^LaneMaskLe();
}
// Returns the warp lane mask of all lanes greater than or equal to the calling thread
HIPCUB_DEVICE inline
uint64_t LaneMaskGe()
{
return uint64_t(-1)^LaneMaskLt();
}
// Shuffle funcs
template <
int LOGICAL_WARP_THREADS,
typename T
>
HIPCUB_DEVICE inline
T ShuffleUp(T input,
int src_offset,
int first_thread,
unsigned int member_mask)
{
// Not supported in rocPRIM.
(void) first_thread;
// Member mask is not supported in rocPRIM, because it's
// not supported in ROCm.
(void) member_mask;
return ::rocprim::warp_shuffle_up(
input, src_offset, LOGICAL_WARP_THREADS
);
}
template <
int LOGICAL_WARP_THREADS,
typename T
>
HIPCUB_DEVICE inline
T ShuffleDown(T input,
int src_offset,
int last_thread,
unsigned int member_mask)
{
// Not supported in rocPRIM.
(void) last_thread;
// Member mask is not supported in rocPRIM, because it's
// not supported in ROCm.
(void) member_mask;
return ::rocprim::warp_shuffle_down(
input, src_offset, LOGICAL_WARP_THREADS
);
}
template <
int LOGICAL_WARP_THREADS,
typename T
>
HIPCUB_DEVICE inline
T ShuffleIndex(T input,
int src_lane,
unsigned int member_mask)
{
// Member mask is not supported in rocPRIM, because it's
// not supported in ROCm.
(void) member_mask;
return ::rocprim::warp_shuffle(
input, src_lane, LOGICAL_WARP_THREADS
);
}
// Other
HIPCUB_DEVICE inline
unsigned int SHR_ADD(unsigned int x,
unsigned int shift,
unsigned int addend)
{
return (x >> shift) + addend;
}
HIPCUB_DEVICE inline
unsigned int SHL_ADD(unsigned int x,
unsigned int shift,
unsigned int addend)
{
return (x << shift) + addend;
}
namespace detail {
template <typename UnsignedBits>
HIPCUB_DEVICE inline
auto unsigned_bit_extract(UnsignedBits source,
unsigned int bit_start,
unsigned int num_bits)
-> typename std::enable_if<sizeof(UnsignedBits) == 8, unsigned int>::type
{
#ifdef __HIP_PLATFORM_AMD__
return __bitextract_u64(source, bit_start, num_bits);
#else
return (source << (64 - bit_start - num_bits)) >> (64 - num_bits);
#endif // __HIP_PLATFORM_AMD__
}
template <typename UnsignedBits>
HIPCUB_DEVICE inline
auto unsigned_bit_extract(UnsignedBits source,
unsigned int bit_start,
unsigned int num_bits)
-> typename std::enable_if<sizeof(UnsignedBits) < 8, unsigned int>::type
{
#ifdef __HIP_PLATFORM_AMD__
return __bitextract_u32(source, bit_start, num_bits);
#else
return (static_cast<unsigned int>(source) << (32 - bit_start - num_bits)) >> (32 - num_bits);
#endif // __HIP_PLATFORM_AMD__
}
} // end namespace detail
// Bitfield-extract.
// Extracts \p num_bits from \p source starting at bit-offset \p bit_start.
// The input \p source may be an 8b, 16b, 32b, or 64b unsigned integer type.
template <typename UnsignedBits>
HIPCUB_DEVICE inline
unsigned int BFE(UnsignedBits source,
unsigned int bit_start,
unsigned int num_bits)
{
static_assert(std::is_unsigned<UnsignedBits>::value, "UnsignedBits must be unsigned");
return detail::unsigned_bit_extract(source, bit_start, num_bits);
}
// Bitfield insert.
// Inserts the \p num_bits least significant bits of \p y into \p x at bit-offset \p bit_start.
HIPCUB_DEVICE inline
void BFI(unsigned int &ret,
unsigned int x,
unsigned int y,
unsigned int bit_start,
unsigned int num_bits)
{
#ifdef __HIP_PLATFORM_AMD__
ret = __bitinsert_u32(x, y, bit_start, num_bits);
#else
x <<= bit_start;
unsigned int MASK_X = ((1 << num_bits) - 1) << bit_start;
unsigned int MASK_Y = ~MASK_X;
ret = (y & MASK_Y) | (x & MASK_X);
#endif // __HIP_PLATFORM_AMD__
}
HIPCUB_DEVICE inline
unsigned int IADD3(unsigned int x, unsigned int y, unsigned int z)
{
return x + y + z;
}
HIPCUB_DEVICE inline
int PRMT(unsigned int a, unsigned int b, unsigned int index)
{
return ::__byte_perm(a, b, index);
}
HIPCUB_DEVICE inline
void BAR(int count)
{
(void) count;
__syncthreads();
}
HIPCUB_DEVICE inline
void CTA_SYNC()
{
__syncthreads();
}
HIPCUB_DEVICE inline
void WARP_SYNC(unsigned int member_mask)
{
(void) member_mask;
::rocprim::wave_barrier();
}
HIPCUB_DEVICE inline
int WARP_ANY(int predicate, uint64_t member_mask)
{
(void) member_mask;
return ::__any(predicate);
}
HIPCUB_DEVICE inline
int WARP_ALL(int predicate, uint64_t member_mask)
{
(void) member_mask;
return ::__all(predicate);
}
HIPCUB_DEVICE inline
int64_t WARP_BALLOT(int predicate, uint64_t member_mask)
{
(void) member_mask;
return __ballot(predicate);
}
END_HIPCUB_NAMESPACE
#endif // HIPCUB_ROCPRIM_UTIL_PTX_HPP_