Search Options

Results per page
Sort
Preferred Languages
Advanced

Results 1 - 4 of 4 for getBranchIndex (0.26 sec)

  1. tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc

      return success();
    }
    
    LogicalResult ConvertCaseOp(CaseOp case_op) {
      OpBuilder builder(case_op);
      auto case_region = builder.create<TF::CaseRegionOp>(
          case_op.getLoc(), case_op.getResultTypes(), case_op.getBranchIndex(),
          case_op.getIsStateless(), case_op.getBranches().size());
      CopyDeviceAndUnderscoredAttributes(case_op, case_region);
    
      for (const auto& item : llvm::enumerate(case_region.getBranches())) {
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 11K bytes
    - Viewed (0)
  2. tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc

      // functional op.
      OpBuilder builder(case_region);
      auto case_op = builder.create<CaseOp>(
          case_region.getLoc(), case_region.getResultTypes(),
          case_region.getBranchIndex(), extern_values,
          builder.getArrayAttr(branch_symbols), case_region.getIsStateless());
      CopyAndOverrideAttributes(case_region, case_op, &builder);
    
      // Redirect old results to new results.
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 28.7K bytes
    - Viewed (0)
  3. tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc

    };
    
    LogicalResult FoldConstantCaseOp::matchAndRewrite(
        TF::CaseOp op, PatternRewriter& rewriter) const {
      // Extract the constant cond value.
      DenseIntElementsAttr branch;
      if (!matchPattern(op.getBranchIndex(), m_Constant(&branch))) return failure();
    
      int index = *branch.getValues<int>().begin();
      if (index < 0 || index >= op.num_branches()) index = op.num_branches() - 1;
    
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 146.7K bytes
    - Viewed (0)
  4. tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc

          // the number of regions as an input along with the operands.
          mhlo_op = rewriter.create<DstOpT>(loc, op.getResultTypes(),
                                            adaptor.getBranchIndex(),
                                            op.getBranches().size());
        } else if constexpr (std::is_same<DstOpT, mhlo::WhileOp>::value) {
          llvm::SmallVector<Type, 4> while_result_types;
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Tue Jun 11 20:00:43 UTC 2024
    - 291.8K bytes
    - Viewed (0)
Back to top