Search Options

Results per page
Sort
Preferred Languages
Advanced

Results 1 - 9 of 9 for GetXlaOpsCommonFlags (0.88 sec)

  1. tensorflow/compiler/jit/device_context_test.cc

    #include "tensorflow/core/framework/tensor_testutil.h"
    #include "tsl/lib/core/status_test_util.h"
    
    namespace tensorflow {
    namespace {
    
    static bool Initialized = [] {
      auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api;
      rollout_config.enabled_for_xla_launch_ = true;
      rollout_config.enabled_for_compile_on_demand_ = true;
    
      tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
      return true;
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Sep 06 19:12:29 UTC 2023
    - 3.7K bytes
    - Viewed (0)
  2. tensorflow/compiler/jit/xla_compile_util_test.cc

    }
    
    TEST(XlaCompileUtilTest, PjRtXlaLaunchFlagTest) {
      EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU)));
    
      // Flag is turned on, but no device is allowlisted.
      auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api;
      rollout_config.enabled_for_xla_launch_ = true;
    
      EXPECT_FALSE(UsePjRtForSingleDeviceCompilation(DeviceType(DEVICE_CPU)));
    
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Tue May 16 21:48:05 UTC 2023
    - 6K bytes
    - Viewed (0)
  3. tensorflow/compiler/jit/kernels/xla_ops.cc

        }
        return GetXlaOpsCommonFlags()->tf_xla_async_compilation
                   ? DeviceCompileMode::kAsync
                   : DeviceCompileMode::kLazy;
      }();
    
      bool use_pjrt =
          GetXlaOpsCommonFlags()
              ->tf_xla_use_device_api.IsEnabledInXlaCompileAndRunForDevice(
                  platform_info_.device_type());
    
      if (GetXlaOpsCommonFlags()->tf_xla_always_defer_compilation ||
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Fri May 17 22:46:36 UTC 2024
    - 41.4K bytes
    - Viewed (0)
  4. tensorflow/compiler/jit/xla_compile_util.cc

        TF_RETURN_IF_ERROR(status);
      }
      FixupSourceAndSinkEdges(graph.get());
      return graph;
    }
    
    bool UsePjRtForSingleDeviceCompilation(const DeviceType& device_type) {
      const auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api;
      return rollout_config.IsEnabledInXlaLaunchForDevice(device_type) ||
             rollout_config.IsEnabledInXlaCompileOnDemandForDevice(device_type) ||
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Feb 21 09:53:30 UTC 2024
    - 4.6K bytes
    - Viewed (0)
  5. tensorflow/compiler/jit/xla_launch_util_gpu_test.cc

    }
    
    class PjRtExecutionUtilGpuTest : public OpsTestBase {
     public:
      PjRtExecutionUtilGpuTest() {
        // Set flag to use PJRT for device compilation and execution.
        auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api;
        rollout_config.enabled_for_xla_launch_ = true;
        rollout_config.enabled_for_compile_on_demand_ = true;
        rollout_config.enabled_for_gpu_ = true;
    
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Sep 06 19:12:29 UTC 2023
    - 10K bytes
    - Viewed (0)
  6. tensorflow/compiler/jit/xla_compile_on_demand_op.cc

      std::vector<const Tensor*> inputs = InputsFromContext(ctx);
      std::vector<int> variable_indices =
          GetResourceVariableIndicesFromContext(ctx);
    
      bool use_pjrt =
          GetXlaOpsCommonFlags()
              ->tf_xla_use_device_api.IsEnabledInXlaCompileOnDemandForDevice(
                  platform_info_.device_type());
      if (use_pjrt) {
        std::vector<VariableInfo> variables;
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Feb 29 08:39:39 UTC 2024
    - 13.4K bytes
    - Viewed (0)
  7. tensorflow/compiler/jit/flags.h

    MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
    BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags();
    XlaSparseCoreFlags* GetXlaSparseCoreFlags();
    XlaDeviceFlags* GetXlaDeviceFlags();
    XlaOpsCommonFlags* GetXlaOpsCommonFlags();
    XlaCallModuleFlags* GetXlaCallModuleFlags();
    
    MlirCommonFlags* GetMlirCommonFlags();
    
    void ResetJitCompilerFlags();
    
    const JitRtFlags& GetJitRtFlags();
    
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Apr 17 18:52:57 UTC 2024
    - 14.5K bytes
    - Viewed (0)
  8. tensorflow/compiler/jit/flags.cc

      return sparse_core_flags;
    }
    
    XlaDeviceFlags* GetXlaDeviceFlags() {
      absl::call_once(flags_init, &AllocateAndParseFlags);
      return device_flags;
    }
    
    XlaOpsCommonFlags* GetXlaOpsCommonFlags() {
      absl::call_once(flags_init, &AllocateAndParseFlags);
      return ops_flags;
    }
    
    XlaCallModuleFlags* GetXlaCallModuleFlags() {
      absl::call_once(flags_init, &AllocateAndParseFlags);
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Apr 17 18:52:57 UTC 2024
    - 24.5K bytes
    - Viewed (0)
  9. tensorflow/compiler/jit/xla_launch_util_test.cc

    }
    
    class PjRtExecutionUtilTest : public OpsTestBase {
     public:
      PjRtExecutionUtilTest() {
        // Set flag to use PJRT for device compilation and execution.
        auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api;
        rollout_config.enabled_for_xla_launch_ = true;
        rollout_config.enabled_for_compile_on_demand_ = true;
    
        // Set flag to enable using XLA devices. PJRT currently is only supported
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Feb 21 09:53:30 UTC 2024
    - 28.8K bytes
    - Viewed (0)
Back to top