- Sort Score
- Result 10 results
- Languages All
Results 1 - 10 of 18 for shape_determination_fns_ (0.3 sec)
-
tensorflow/compiler/jit/pjrt_base_device.h
return shape_determination_fns_.at(0); } const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns_at(int i) const { return shape_determination_fns_[i]; } private: const DeviceType jit_device_type_; std::vector<XlaShapeLayoutHelpers::ShapeDeterminationFns> shape_determination_fns_; Metadata(const Metadata&) = delete;
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Wed Feb 21 12:19:41 UTC 2024 - 4K bytes - Viewed (0) -
tensorflow/compiler/jit/pjrt_device_context.h
// devices using PjRt. class PjRtDeviceContext : public DeviceContext { public: explicit PjRtDeviceContext( XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, bool use_pjrt_tensor_buffer = false) : shape_determination_fns_(std::move(shape_determination_fns)), use_pjrt_tensor_buffer_(use_pjrt_tensor_buffer) {} void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Wed Jul 19 19:27:39 UTC 2023 - 2.7K bytes - Viewed (0) -
tensorflow/compiler/jit/xla_device_context.h
return device_to_device_streams_.at(index).get(); } xla::TransferManager* transfer_manager() const { return transfer_manager_; } const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns() const { return shape_determination_fns_; } // Returns a device-to-device stream, in round-robin fashion. se::Stream* GetDeviceToDeviceStream(); Status ThenExecute(Device* device, stream_executor::Stream* stream,
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Wed Sep 06 19:12:29 UTC 2023 - 5.1K bytes - Viewed (0) -
tensorflow/compiler/jit/xla_compiler_options_util.cc
<< ",graph_def_version=" << options.graph_def_version << ",options.shape_determination_fns.layout_preference_fn?=" << (options.shape_determination_fns.layout_preference_fn != nullptr) << ",options.shape_determination_fns.shape_representation_fn?=" << (options.shape_determination_fns.shape_representation_fn != nullptr)
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Wed Feb 21 09:53:30 UTC 2024 - 6.4K bytes - Viewed (0) -
tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.cc
use_tuple_args, true, shape_determination_fns, compilation_result, custom_legalization_passes, metadata.module_name(), lower_to_xla_hlo)); // Compute how arguments are shared across different cores. auto sharding_result = tpu::GetShardingInfo(metadata, arg_shapes, shape_determination_fns, arg_core_mapping, per_core_arg_shapes);
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Sun Apr 14 20:29:34 UTC 2024 - 6.1K bytes - Viewed (0) -
tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc
llvm::StringRef device_type, std::vector<std::unique_ptr<mlir::Pass>>& custom_legalization_passes, XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, const std::vector<tensorflow::TensorShape>& arg_shapes, std::vector<tpu::ShardingAndIndex>* arg_core_mapping, std::vector<std::vector<xla::Shape>>* per_core_arg_shapes,
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Wed May 29 00:40:46 UTC 2024 - 6.8K bytes - Viewed (0) -
tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc
auto status = CompileSerializedMlirToXlaHlo( kMlirModuleStr, arg_shapes, /*device_type=*/"XLA_TPU_JIT", /*use_tuple_args=*/true, /*enable_op_fallback=*/false, /*shape_determination_fns=*/{}, &compilation_result); EXPECT_TRUE(status.ok()); EXPECT_THAT(status.value(), HasSubstr("mhlo.const")); } TEST(LegalizeMlirTest, FailsLegalizesModule) { constexpr char failed_legalization[] = R"(
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Mon Mar 25 19:54:38 UTC 2024 - 9.7K bytes - Viewed (0) -
tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc
const tpu::MlirToHloArgs& computation, const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, llvm::StringRef device_type, XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, const std::vector<tensorflow::TensorShape>& arg_shapes, std::vector<tpu::ShardingAndIndex>* arg_core_mapping, std::vector<std::vector<xla::Shape>>* per_core_arg_shapes,
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Sun Apr 14 20:29:34 UTC 2024 - 3.7K bytes - Viewed (0) -
tensorflow/compiler/jit/pjrt_base_device.cc
options.device_ordinal)), metadata_(DeviceType(options.compilation_device_name), options.shape_determination_fns) { if (options.shape_determination_fns.empty()) { LOG(ERROR) << "shape_representation_fns must be non-empty."; } VLOG(1) << "Created PJRT base device " << options.compilation_device_name << " device_name: " << name();
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Wed Feb 21 12:19:41 UTC 2024 - 2.5K bytes - Viewed (0) -
tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h
bool lower_to_xla_hlo, const tpu::MlirToHloArgs& computation, const tpu::TPUCompileMetadataProto& metadata, llvm::StringRef device_type, const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns, bool use_tuple_args, XlaCompiler::CompilationResult* compilation_result, std::vector<std::unique_ptr<mlir::Pass>>& custom_legalization_passes, const std::vector<TensorShape>& arg_shapes,
Registered: Sun Jun 16 05:45:23 UTC 2024 - Last Modified: Sun Apr 14 20:29:34 UTC 2024 - 2.8K bytes - Viewed (0)