Search Options

Results per page
Sort
Preferred Languages
Advanced

Results 11 - 20 of 32 for PartitionedCallOp (0.58 sec)

  1. tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc

      }
      if (skip_resize_) {
        target.addLegalOp<TF::ResizeBilinearOp>();
        target.addLegalOp<TF::ResizeNearestNeighborOp>();
      }
      if (skip_partitioned_calls_) {
        target.addLegalOp<TF::PartitionedCallOp>();
        target.addLegalOp<TF::StatefulPartitionedCallOp>();
      }
    
      FrozenRewritePatternSet frozen_patterns(std::move(patterns));
      if (failed(applyPartialConversion(func, target, frozen_patterns))) {
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Tue May 28 21:49:50 UTC 2024
    - 7.5K bytes
    - Viewed (0)
  2. tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc

    // than function name.
    std::unique_ptr<OpQuantSpec> GetTFOpQuantSpec(Operation* op) {
      auto spec = std::make_unique<OpQuantSpec>();
      if (auto call_op = dyn_cast<TF::PartitionedCallOp>(op)) {
        StringRef function_name =
            mlir::cast<FlatSymbolRefAttr>(call_op.getFAttr()).getValue();
        if (!function_name.starts_with("composite_")) {
          return spec;
        }
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 6.3K bytes
    - Viewed (0)
  3. tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_hashtable_ops_as_args.cc

      auto function_uses =
          SymbolTable::getSymbolUses(func, &module.getBodyRegion());
      if (!function_uses.has_value()) return false;
      for (auto& function_use : function_uses.value()) {
        if (!llvm::isa<TF::PartitionedCallOp, TF::StatefulPartitionedCallOp>(
                function_use.getUser())) {
          return false;
        }
      }
      return true;
    }
    
    // Returns the `shared_name` attribute value if exists. If not, returns an
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Fri May 17 17:58:54 UTC 2024
    - 8.2K bytes
    - Viewed (0)
  4. tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc

        runOnOperation() {
      ModuleOp module_op = getOperation();
    
      func::FuncOp main_func = FindMainFuncOp(module_op);
      if (!main_func) return;
    
      // In case the model has tf.StatefulPartitionedCallOp or tf.PartitionedCallOp,
      // we recursively find called functions and process StableHLO ops in them.
      SmallVector<func::FuncOp> func_ops;
      func_ops.push_back(main_func);
      int stablehlo_func_id = -1;
      while (!func_ops.empty()) {
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 21K bytes
    - Viewed (0)
  5. tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc

        // The function is in place in the nested module, create a call and yield in
        // the original island.
        OpBuilder builder = OpBuilder::atBlockEnd(&island_op.GetBody());
        auto call_op = builder.create<mlir::TF::PartitionedCallOp>(
            island_op.getLoc(), func_result_types, operands.getArrayRef(),
            SymbolRefAttr::get(
                builder.getContext(), kNestedModule,
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 8.7K bytes
    - Viewed (0)
  6. tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc

    bool SupportsCommunicationComputation(Operation* op) {
      return isa<TF::IfRegionOp, TF::WhileRegionOp, TF::CaseRegionOp,
                 TF::XlaCallModuleOp, TF::StatefulPartitionedCallOp,
                 TF::PartitionedCallOp, TF::LegacyCallOp>(op);
    }
    
    #define GEN_PASS_DEF_PREPARETPUCOMPUTATIONFORTFEXPORTPASS
    #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc"
    
    class PrepareTpuComputationForTfExportPass
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 11.8K bytes
    - Viewed (0)
  7. tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h

        }
      }
      return nullptr;
    }
    
    // Returns the function attribute for the given call op which is lifted for
    // quantization.
    inline FlatSymbolRefAttr GetFuncAttr(TF::PartitionedCallOp call_op) {
      return mlir::dyn_cast<FlatSymbolRefAttr>(call_op.getFAttr());
    }
    
    inline FlatSymbolRefAttr GetFuncAttr(TF::XlaCallModuleOp call_op) {
      return call_op->getAttrOfType<FlatSymbolRefAttr>(
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 9.9K bytes
    - Viewed (0)
  8. tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc

    bool IsAlwaysAllowlistedOp(Operation *op) {
      return llvm::isa<
          // clang-format off
          // go/keep-sorted start
          TF::ConstOp,
          TF::IdentityOp,
          TF::PartitionedCallOp,
          TF::StatefulPartitionedCallOp
          // go/keep-sorted end
          // clang-format on
          >(op);
    }
    
    // LINT.IfChange
    // The list of quantizable ops in the Legacy Integer mode.
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 12.2K bytes
    - Viewed (0)
  9. tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc

          has_lower_as_multi_device_function_attr = lower.getValue();
        tensorflow::FunctionCallInlinePolicy policy =
            tensorflow::GetFunctionCallInlinePolicy(
                isa<PartitionedCallOp, StatefulPartitionedCallOp>(call),
                has_lower_as_multi_device_function_attr);
    
        if (policy == tensorflow::FunctionCallInlinePolicy::kMultiDevicePlacer)
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 14.6K bytes
    - Viewed (0)
  10. tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc

        llvm::ArrayRef<Type> new_types = llvm::ArrayRef(
            result_types.begin() + result_idx, func_op.getNumResults());
        result_idx += func_op.getNumResults();
    
        auto call_op = builder.create<TF::PartitionedCallOp>(
            module_op.getLoc(), new_types, new_args,
            SymbolRefAttr::get(context, func_op.getSymName()),
            /*config=*/builder.getStringAttr(""),
            /*config_proto=*/builder.getStringAttr(""),
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 16.5K bytes
    - Viewed (0)
Back to top