Search Options

Results per page
Sort
Preferred Languages
Advanced

Results 1 - 5 of 5 for PjRtBaseDevice (0.17 sec)

  1. tensorflow/compiler/jit/pjrt_base_device.cc

      VLOG(1) << "Created PJRT base device " << options.compilation_device_name
              << " device_name: " << name();
    }
    
    /*static*/ absl::StatusOr<const PjRtBaseDevice::Metadata*>
    PjRtBaseDevice::GetMetadataFromDevice(DeviceBase* device) {
      PjRtBaseDevice* pjrt_device =
          dynamic_cast<PjRtBaseDevice*>(device->UnderlyingDevice());
      if (pjrt_device == nullptr) {
        return errors::Internal(
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Feb 21 12:19:41 UTC 2024
    - 2.5K bytes
    - Viewed (0)
  2. tensorflow/compiler/jit/pjrt_base_device.h

    #include "tensorflow/core/framework/device_base.h"
    
    namespace tensorflow {
    
    // tensorflow::PjRtBaseDevice replaces the deprecated tensorflow::XlaDevice.
    // This accelerator agnostic device is mainly used to store metadata.
    class PjRtBaseDevice : public LocalDevice {
     public:
      // Stores metadata about the PjRtBaseDevice.
      class Metadata {
       public:
        Metadata(const DeviceType& jit_device_type,
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Feb 21 12:19:41 UTC 2024
    - 4K bytes
    - Viewed (0)
  3. tensorflow/compiler/jit/xla_platform_info.h

    // abstraction for normal, XLA devices and devices inheriting from
    // PjRtBaseDevice.
    class XlaPlatformInfo {
     public:
      XlaPlatformInfo() : device_type_("") {}
      XlaPlatformInfo(XlaPlatformInfo&&) = default;
      explicit XlaPlatformInfo(
          const DeviceType device_type, se::Platform::Id platform_id,
          const XlaDevice::Metadata* xla_device_metadata,
          const PjRtBaseDevice::Metadata* pjrt_device_metadata,
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Feb 21 09:53:30 UTC 2024
    - 7.2K bytes
    - Viewed (0)
  4. tensorflow/compiler/jit/xla_compiler_options_util_test.cc

          GetShapeDeterminationFns(), XlaDevice::PaddedShapeFn(),
          /*use_multiple_streams=*/false);
    }
    
    std::unique_ptr<PjRtBaseDevice::Metadata> CreatePjRtDeviceMetadata(
        DeviceType compilation_device_type) {
      return std::make_unique<PjRtBaseDevice::Metadata>(compilation_device_type,
                                                        GetShapeDeterminationFns());
    }
    
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Fri Dec 29 01:41:20 UTC 2023
    - 14.8K bytes
    - Viewed (0)
  5. tensorflow/compiler/jit/xla_platform_info.cc

    }
    
    XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
      se::Platform::Id platform_id = nullptr;
      const XlaDevice::Metadata* xla_device_metadata = nullptr;
      const PjRtBaseDevice::Metadata* pjrt_device_metadata = nullptr;
      std::shared_ptr<se::DeviceMemoryAllocator> custom_allocator;
    
      const std::string& device_type = device_base->device_type();
      if (device_type == DEVICE_CPU) {
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu May 02 17:23:27 UTC 2024
    - 17.4K bytes
    - Viewed (0)
Back to top