Search Options

Results per page
Sort
Preferred Languages
Advance

Results 1 - 10 of 29 for GetBoolAttr (0.14 sec)

  1. tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc

                    builder.getBoolAttr(true));
      ConstOp shape = builder.create<ConstOp>(loc, shape_attr);
      shape->setAttr(kICIWeightDistributionMlirBridgeMarker,
                     builder.getBoolAttr(true));
      FillOp fill = builder.create<FillOp>(loc, shape, zero);
      fill->setAttr(kICIWeightDistributionMlirBridgeMarker,
                    builder.getBoolAttr(true));
      return fill;
    }
    
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Jun 13 18:52:07 UTC 2024
    - 13.9K bytes
    - Viewed (0)
  2. tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc

        Type expected = fn_type.getInput(i);
        if (val.getType() != expected) {
          val =
              builder->create<TF::CastOp>(loc, expected, val,
                                          /*Truncate=*/builder->getBoolAttr(false));
        }
        operands.push_back(val);
      }
      return builder->create<func::CallOp>(loc, fn, operands).getOperation();
    }
    
    // Prepares for jump to the given block by introducing necessary tensor_cast
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Fri Jan 13 11:42:59 UTC 2023
    - 12.2K bytes
    - Viewed (0)
  3. tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc

                                                   rewriter.getF32Type());
        ::mlir::Value mean = rewriter.create<TF::MeanOp>(
            bn_op.getLoc(), mean_var_type, inputs, reduce_dim_op,
            /*keep_dims=*/rewriter.getBoolAttr(false));
    
        // Compute variance
        Value shape_value =
            getShapeValue(bn_op.getLoc(), bn_op.getOperand(), rewriter);
        auto broadcast_mean = broadcastToFeatureDim(
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 11.2K bytes
    - Viewed (0)
  4. tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc

        return rewriter.createOrFold<TF::CastOp>(loc, new_type, val,
                                                 rewriter.getBoolAttr(false));
      }
      return rewriter.createOrFold<TF::CastOp>(
          loc, UnrankedTensorType::get(new_ele_type), val,
          rewriter.getBoolAttr(false));
    }
    
    // Utility function to-
    // 1. Create a tfl.const op with an int32_t values, from an MLIR Value, if the
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Mon May 20 20:06:54 UTC 2024
    - 45.2K bytes
    - Viewed (0)
  5. tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc

      // in TensorFlow and allows lowering to V1 control flow for loop
      // parallelization.
      dataset_while->setAttr("_lower_using_switch_merge",
                             builder.getBoolAttr(true));
    
      return dataset_while;
    }
    
    // Populate the cond of `dataset_while`.  The cond body just returns the
    // condition of whether to continue to next iteration.
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 14K bytes
    - Viewed (0)
  6. tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc

      auto matmul_type =
          RankedTensorType::get(matmul_shape, original_type.getElementType());
      Value out = rewriter.create<TF::BatchMatMulV2Op>(
          op.getLoc(), matmul_type, lhs, rhs, rewriter.getBoolAttr(false),
          rewriter.getBoolAttr(false));
    
      bool out_reshape_need = (reshape_shape.size() != matmul_shape.size() ||
                               original_type.getRank() != matmul_shape.size());
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 33.3K bytes
    - Viewed (0)
  7. tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td

    def convertIntAttrTo32Bit : NativeCodeCall<
        "$_builder.getI32IntegerAttr($0.cast<IntegerAttr>().getInt())">;
    
    // Builds a constant bool attribute.
    class GetBoolAttr<int value> :
        NativeCodeCall<"$_builder.getBoolAttr(" # value #")">;
    
    // Converts an integer attribute $0 to 64-bit with builder.
    def convertIntAttrTo64Bit : NativeCodeCall<
        "$_builder.getI64IntegerAttr($0.cast<IntegerAttr>().getInt())">;
    
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Tue Jun 04 13:30:42 UTC 2024
    - 28.5K bytes
    - Viewed (0)
  8. tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc

          loc, tensor_type.clone(builder.getIntegerType(32)), tensor);
      auto reduced =
          builder.create<TF::SumOp>(loc, tensor_i32, reduction_indices_value,
                                    /*keep_dims=*/builder.getBoolAttr(true));
      auto mul_op = builder.create<TF::MulOp>(loc, zp, reduced);
    
      SmallVector<Value> folded_results = ConstantFoldOpIfPossible(mul_op);
      return folded_results.front();
    }
    
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 47.1K bytes
    - Viewed (0)
  9. tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc

        Type expected_type = std::get<1>(ArgAndType);
        if (arg.getType() != expected_type) {
          arg = builder.create<CastOp>(loc, expected_type, arg,
                                       /*Truncate=*/builder.getBoolAttr(false));
        }
        casted_args.push_back(arg);
      }
      auto call = builder.create<func::CallOp>(loc, func, casted_args);
    
      auto results = call.getResults();
      auto block_args = entry->getArguments();
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 11K bytes
    - Viewed (0)
  10. tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc

          if (operand.getType() != expected_type) {
            operand = rewriter.create<TF::CastOp>(
                op.getLoc(), expected_type, operand,
                /*Truncate=*/rewriter.getBoolAttr(false));
          }
          casted_operands.push_back(operand);
        }
    
        auto call = rewriter.create<func::CallOp>(
            op->getLoc(), main_fn.getSymName(), main_fn.getResultTypes(),
            casted_operands);
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Jan 25 09:43:18 UTC 2024
    - 10.9K bytes
    - Viewed (0)
Back to top