Search Options

Results per page
Sort
Preferred Languages
Advanced

Results 1 - 10 of 19 for shape_determination_fns (0.33 sec)

  1. tensorflow/compiler/jit/pjrt_base_device.h

        // Returns the first (default) entry of shape_determination_fns_.
        // Lookup goes through at(0), so an empty container is reported as an
        // out-of-range access rather than read past the end.
        const XlaShapeLayoutHelpers::ShapeDeterminationFns&
        default_shape_determination_fns() const {
          const auto& all_fns = shape_determination_fns_;
          return all_fns.at(0);
        }
    
        // Returns the shape determination functions stored at index `i`.
        // Uses bounds-checked at() to match default_shape_determination_fns():
        // an out-of-range index (including a negative `i`, which converts to a
        // huge unsigned position) fails deterministically instead of causing
        // the undefined behavior of an unchecked operator[] read.
        const XlaShapeLayoutHelpers::ShapeDeterminationFns&
        shape_determination_fns_at(int i) const {
          return shape_determination_fns_.at(i);
        }
    
       private:
        const DeviceType jit_device_type_;
        std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns>
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Feb 21 12:19:41 UTC 2024
    - 4K bytes
    - Viewed (0)
  2. tensorflow/compiler/jit/xla_compiler_options_util.cc

              << ",graph_def_version=" << options.graph_def_version
              << ",options.shape_determination_fns.layout_preference_fn?="
              << (options.shape_determination_fns.layout_preference_fn != nullptr)
              << ",options.shape_determination_fns.shape_representation_fn?="
              << (options.shape_determination_fns.shape_representation_fn !=
                  nullptr)
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Feb 21 09:53:30 UTC 2024
    - 6.4K bytes
    - Viewed (0)
  3. tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc

              use_tuple_args, true, shape_determination_fns, compilation_result,
              custom_legalization_passes, metadata.module_name(),
              lower_to_xla_hlo));
    
      // Compute how arguments are shared across different cores.
      auto sharding_result =
          tpu::GetShardingInfo(metadata, arg_shapes, shape_determination_fns,
                               arg_core_mapping, per_core_arg_shapes);
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Sun Apr 14 20:29:34 UTC 2024
    - 6.1K bytes
    - Viewed (0)
  4. tensorflow/compiler/jit/xla_device_context.h

        return device_to_device_streams_.at(index).get();
      }
      // Returns the stored transfer_manager_ pointer.
      xla::TransferManager* transfer_manager() const {
        return transfer_manager_;
      }
      // Returns a const reference to the stored shape determination functions.
      const XlaShapeLayoutHelpers::ShapeDeterminationFns&
      shape_determination_fns() const {
        return shape_determination_fns_;
      }
    
      // Returns a device-to-device stream, in round-robin fashion.
      se::Stream* GetDeviceToDeviceStream();
    
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Sep 06 19:12:29 UTC 2023
    - 5.1K bytes
    - Viewed (0)
  5. tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc

        const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args,
        llvm::StringRef device_type,
        std::vector<std::unique_ptr<mlir::Pass>>& custom_legalization_passes,
        XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
        const std::vector<tensorflow::TensorShape>& arg_shapes,
        std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
        std::vector<std::vector<xla::Shape>>* per_core_arg_shapes,
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed May 29 00:40:46 UTC 2024
    - 6.8K bytes
    - Viewed (0)
  6. tensorflow/compiler/jit/pjrt_device_context.h

    // devices using PjRt.
    class PjRtDeviceContext : public DeviceContext {
     public:
      // Builds a device context from a set of shape determination functions.
      //   shape_determination_fns: taken by value and moved into the member.
      //   use_pjrt_tensor_buffer: stored as-is; defaults to false.
      explicit PjRtDeviceContext(
          XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
          bool use_pjrt_tensor_buffer = false)
          : shape_determination_fns_(
                std::move(shape_determination_fns)),
            use_pjrt_tensor_buffer_(
                use_pjrt_tensor_buffer) {
      }
    
      void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Jul 19 19:27:39 UTC 2023
    - 2.7K bytes
    - Viewed (0)
  7. tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc

      auto status = CompileSerializedMlirToXlaHlo(
          kMlirModuleStr, arg_shapes, /*device_type=*/"XLA_TPU_JIT",
          /*use_tuple_args=*/true, /*enable_op_fallback=*/false,
          /*shape_determination_fns=*/{}, &compilation_result);
    
      EXPECT_TRUE(status.ok());
      EXPECT_THAT(status.value(), HasSubstr("mhlo.const"));
    }
    
    TEST(LegalizeMlirTest, FailsLegalizesModule) {
      constexpr char failed_legalization[] = R"(
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Mon Mar 25 19:54:38 UTC 2024
    - 9.7K bytes
    - Viewed (0)
  8. tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc

        const tpu::MlirToHloArgs& computation,
        const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args,
        llvm::StringRef device_type,
        XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
        const std::vector<tensorflow::TensorShape>& arg_shapes,
        std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
        std::vector<std::vector<xla::Shape>>* per_core_arg_shapes,
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Sun Apr 14 20:29:34 UTC 2024
    - 3.7K bytes
    - Viewed (0)
  9. tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h

        bool lower_to_xla_hlo, const tpu::MlirToHloArgs& computation,
        const tpu::TPUCompileMetadataProto& metadata, llvm::StringRef device_type,
        const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns,
        bool use_tuple_args, XlaCompiler::CompilationResult* compilation_result,
        std::vector<std::unique_ptr<mlir::Pass>>& custom_legalization_passes,
        const std::vector<TensorShape>& arg_shapes,
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Sun Apr 14 20:29:34 UTC 2024
    - 2.8K bytes
    - Viewed (0)
  10. tensorflow/compiler/jit/pjrt_base_device.cc

                                                    options.device_ordinal)),
          metadata_(DeviceType(options.compilation_device_name),
                    options.shape_determination_fns) {
      if (options.shape_determination_fns.empty()) {
        LOG(ERROR) << "shape_representation_fns must be non-empty.";
      }
      VLOG(1) << "Created PJRT base device " << options.compilation_device_name
              << " device_name: " << name();
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Feb 21 12:19:41 UTC 2024
    - 2.5K bytes
    - Viewed (0)
Back to top